Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/pmpalang/pyg-lib
Browse files Browse the repository at this point in the history
  • Loading branch information
pmpalang committed Nov 14, 2023
2 parents 3c31852 + d99233a commit 7c666d4
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 88 deletions.
138 changes: 71 additions & 67 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,25 @@ class NeighborSampler {
}

void edge_temporal_sample(const node_t global_src_node,
const scalar_t local_src_node,
const temporal_t* edge_times,
const int64_t count,
const temporal_t seed_time,
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
const scalar_t local_src_node,
const temporal_t* edge_times,
const int64_t count,
const temporal_t seed_time,
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
auto row_start = rowptr_[to_scalar_t(global_src_node)];
auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];

if ((row_end - row_start == 0) || (count == 0))
return;
// Find new `row_end` such that all neighbors fulfill temporal constraints:
std::vector<int> v(row_end-row_start);
std::iota (std::begin(v), std::end(v), row_start);
auto it = std::upper_bound(
v.begin(), v.end(), seed_time,
[&](const scalar_t& a, const scalar_t& b) { return a < edge_times[b]; });
std::vector<int> v(row_end - row_start);
std::iota(std::begin(v), std::end(v), row_start);
auto it = std::upper_bound(v.begin(), v.end(), seed_time,
[&](const scalar_t& a, const scalar_t& b) {
return a < edge_times[b];
});
row_end = it - v.begin() + row_start;

if (temporal_strategy_ == "last" && count >= 0) {
Expand All @@ -103,20 +104,18 @@ class NeighborSampler {
"Found invalid non-sorted temporal neighborhood");
}


_sample(global_src_node, local_src_node, row_start, row_end, count,
dst_mapper, generator, out_global_dst_nodes);
}



void node_temporal_sample(const node_t global_src_node,
const scalar_t local_src_node,
const int64_t count,
const temporal_t seed_time,
const temporal_t* time,
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
const scalar_t local_src_node,
const int64_t count,
const temporal_t seed_time,
const temporal_t* time,
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
auto row_start = rowptr_[to_scalar_t(global_src_node)];
auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];

Expand Down Expand Up @@ -360,10 +359,12 @@ sample(const at::Tensor& rowptr,
TORCH_CHECK(col.is_contiguous(), "Non-contiguous 'col'");
TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'");
if (node_time.has_value()) {
TORCH_CHECK(node_time.value().is_contiguous(), "Non-contiguous 'node_time'");
TORCH_CHECK(node_time.value().is_contiguous(),
"Non-contiguous 'node_time'");
}
if (edge_time.has_value()) {
TORCH_CHECK(edge_time.value().is_contiguous(), "Non-contiguous 'edge_time'");
TORCH_CHECK(edge_time.value().is_contiguous(),
"Non-contiguous 'edge_time'");
}
if (seed_time.has_value()) {
TORCH_CHECK(seed_time.value().is_contiguous(),
Expand Down Expand Up @@ -461,7 +462,7 @@ sample(const at::Tensor& rowptr,
const auto batch_idx = sampled_nodes[i].first;
sampler.edge_temporal_sample(
/*global_src_node=*/sampled_nodes[i],
/*local_src_node=*/i,
/*local_src_node=*/i,
/*edge_time=*/edge_time_data,
/*count=*/count,
/*seed_time=*/seed_times[batch_idx],
Expand All @@ -471,8 +472,7 @@ sample(const at::Tensor& rowptr,
if constexpr (distributed)
cumsum_neighbors_per_node.push_back(sampled_nodes.size());
}
}
else {
} else {
const auto time_data = node_time.value().data_ptr<temporal_t>();
for (size_t i = begin; i < end; ++i) {
const auto batch_idx = sampled_nodes[i].first;
Expand Down Expand Up @@ -690,7 +690,6 @@ sample(const std::vector<node_type>& node_types,
// Not supported
exit(0);
}

}

num_sampled_nodes_per_hop_map.at(kv.key())[0] =
Expand Down Expand Up @@ -738,8 +737,10 @@ sample(const std::vector<node_type>& node_types,
/*generator=*/generator,
/*out_global_dst_nodes=*/dst_sampled_nodes);
}
} else if ((!node_time_dict.has_value() || !node_time_dict.value().contains(dst)) &&
(!edge_time_dict.has_value() || !edge_time_dict.value().contains(to_rel_type(k)))) {
} else if ((!node_time_dict.has_value() ||
!node_time_dict.value().contains(dst)) &&
(!edge_time_dict.has_value() ||
!edge_time_dict.value().contains(to_rel_type(k)))) {
for (size_t i = begin; i < end; ++i) {
sampler.uniform_sample(
/*global_src_node=*/src_sampled_nodes[i],
Expand All @@ -749,41 +750,44 @@ sample(const std::vector<node_type>& node_types,
/*generator=*/generator,
/*out_global_dst_nodes=*/dst_sampled_nodes);
}
} else if constexpr (!std::is_scalar<node_t>::value){
if (edge_time_dict.has_value() && edge_time_dict.value().contains(to_rel_type(k))) {
//Edge temporal sampling
const at::Tensor& edge_times = edge_time_dict.value().at(to_rel_type(k));
const auto edge_times_data = edge_times.data_ptr<temporal_t>();
for (size_t i = begin; i < end; ++i) {
const auto batch_idx = src_sampled_nodes[i].first;
sampler.edge_temporal_sample(
/*global_src_node=*/src_sampled_nodes[i],
/*local_src_node=*/i,
/*edge_times=*/edge_times_data,
/*count=*/count,
/*seed_time=*/seed_times[batch_idx],
/*dst_mapper=*/dst_mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/dst_sampled_nodes);
}
} else {
// Temporal sampling:
const at::Tensor& dst_time = node_time_dict.value().at(dst);
const auto dst_time_data = dst_time.data_ptr<temporal_t>();
for (size_t i = begin; i < end; ++i) {
const auto batch_idx = src_sampled_nodes[i].first;
sampler.node_temporal_sample(
/*global_src_node=*/src_sampled_nodes[i],
/*local_src_node=*/i,
/*count=*/count,
/*seed_time=*/seed_times[batch_idx],
/*time=*/dst_time_data,
/*dst_mapper=*/dst_mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/dst_sampled_nodes);
} else if constexpr (!std::is_scalar<node_t>::value) {
if (edge_time_dict.has_value() &&
edge_time_dict.value().contains(to_rel_type(k))) {
// Edge temporal sampling
const at::Tensor& edge_times =
edge_time_dict.value().at(to_rel_type(k));
const auto edge_times_data =
edge_times.data_ptr<temporal_t>();
for (size_t i = begin; i < end; ++i) {
const auto batch_idx = src_sampled_nodes[i].first;
sampler.edge_temporal_sample(
/*global_src_node=*/src_sampled_nodes[i],
/*local_src_node=*/i,
/*edge_times=*/edge_times_data,
/*count=*/count,
/*seed_time=*/seed_times[batch_idx],
/*dst_mapper=*/dst_mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/dst_sampled_nodes);
}
} else {
// Temporal sampling:
const at::Tensor& dst_time = node_time_dict.value().at(dst);
const auto dst_time_data = dst_time.data_ptr<temporal_t>();
for (size_t i = begin; i < end; ++i) {
const auto batch_idx = src_sampled_nodes[i].first;
sampler.node_temporal_sample(
/*global_src_node=*/src_sampled_nodes[i],
/*local_src_node=*/i,
/*count=*/count,
/*seed_time=*/seed_times[batch_idx],
/*time=*/dst_time_data,
/*dst_mapper=*/dst_mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/dst_sampled_nodes);
}
}
}
}
}
}
});
Expand Down Expand Up @@ -908,8 +912,8 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
bool return_edge_id) {
const auto out = [&] {
DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col,
seed, num_neighbors, node_time, seed_time, edge_time, edge_weight, csc,
temporal_strategy);
seed, num_neighbors, node_time, seed_time, edge_time,
edge_weight, csc, temporal_strategy);
}();
return std::make_tuple(std::get<0>(out), std::get<1>(out), std::get<2>(out),
std::get<3>(out), std::get<4>(out), std::get<5>(out));
Expand Down Expand Up @@ -960,8 +964,8 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
std::string temporal_strategy) {
const auto out = [&] {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, rowptr, col, seed,
{num_neighbors}, node_time, seed_time, edge_time, edge_weight, csc,
temporal_strategy);
{num_neighbors}, node_time, seed_time, edge_time,
edge_weight, csc, temporal_strategy);
}();
return std::make_tuple(std::get<2>(out), std::get<3>(out).value(),
std::get<6>(out));
Expand Down
21 changes: 12 additions & 9 deletions pyg_lib/csrc/sampler/neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ neighbor_sample(const at::Tensor& rowptr,
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::neighbor_sample", "")
.typed<decltype(neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, node_time, seed_time, edge_time, edge_weight,
csc, replace, directed, disjoint, temporal_strategy,
return_edge_id);
return op.call(rowptr, col, seed, num_neighbors, node_time, seed_time,
edge_time, edge_weight, csc, replace, directed, disjoint,
temporal_strategy, return_edge_id);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
Expand Down Expand Up @@ -92,8 +92,8 @@ hetero_neighbor_sample(
.typed<decltype(hetero_neighbor_sample)>();
return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, node_time_dict, seed_time_dict,
edge_time_dict, edge_weight_dict, csc, replace, directed, disjoint,
temporal_strategy, return_edge_id);
edge_time_dict, edge_weight_dict, csc, replace, directed,
disjoint, temporal_strategy, return_edge_id);
}

std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
Expand Down Expand Up @@ -121,14 +121,16 @@ std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::dist_neighbor_sample", "")
.typed<decltype(dist_neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, node_time, seed_time, edge_time, edge_weight,
csc, replace, directed, disjoint, temporal_strategy);
return op.call(rowptr, col, seed, num_neighbors, node_time, seed_time,
edge_time, edge_weight, csc, replace, directed, disjoint,
temporal_strategy);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? edge_time = None, Tensor? "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"edge_time = None, Tensor? "
"edge_weight = None, bool csc = False, bool replace = False, bool "
"directed = True, bool disjoint = False, str temporal_strategy = "
"'uniform', bool return_edge_id = True) -> "
Expand All @@ -146,7 +148,8 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) {
"Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? edge_time = None, Tensor? "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"edge_time = None, Tensor? "
"edge_weight = None, bool csc = False, bool replace = False, bool "
"directed = True, bool disjoint = False, str temporal_strategy = "
"'uniform') -> (Tensor, Tensor, int[])"));
Expand Down
5 changes: 1 addition & 4 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,7 @@ def hetero_neighbor_sample(
for k, v in edge_weight_dict.items()
}
if edge_time_dict is not None:
edge_time_dict = {
TO_REL_TYPE[k]: v
for k, v in edge_time_dict.items()
}
edge_time_dict = {TO_REL_TYPE[k]: v for k, v in edge_time_dict.items()}

out = torch.ops.pyg.hetero_neighbor_sample(
node_types,
Expand Down
16 changes: 8 additions & 8 deletions test/csrc/sampler/test_neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,14 @@ TEST(BiasedNeighborTest, BasicAssertions) {
/*edge_time=*/edge_weight,
/*edge_weight=*/edge_weight);

auto expected_row = at::tensor({0, 1}, options);
EXPECT_TRUE(at::equal(std::get<0>(out), expected_row));
auto expected_col = at::tensor({2, 0}, options);
EXPECT_TRUE(at::equal(std::get<1>(out), expected_col));
auto expected_nodes = at::tensor({0, 1, 5}, options);
EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes));
auto expected_edges = at::tensor({0, 2}, options);
EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges));
auto expected_row = at::tensor({0, 1}, options);
EXPECT_TRUE(at::equal(std::get<0>(out), expected_row));
auto expected_col = at::tensor({2, 0}, options);
EXPECT_TRUE(at::equal(std::get<1>(out), expected_col));
auto expected_nodes = at::tensor({0, 1, 5}, options);
EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes));
auto expected_edges = at::tensor({0, 2}, options);
EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges));
}

TEST(HeteroBiasedNeighborTest, BasicAssertions) {
Expand Down

0 comments on commit 7c666d4

Please sign in to comment.