Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Nov 14, 2023
1 parent 80e295b commit faece9a
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 120 deletions.
2 changes: 1 addition & 1 deletion benchmark/sampler/hetero_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def test_hetero_neighbor(dataset, **kwargs):
seed_dict,
num_neighbors_dict,
node_time_dict,
seed_time_dict=None,
edge_time_dict=None,
seed_time_dict=None,
edge_weight_dict=edge_weight_dict,
csc=True,
replace=False,
Expand Down
2 changes: 1 addition & 1 deletion benchmark/sampler/neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def test_neighbor(dataset, **kwargs):
seed,
num_neighbors,
node_time=node_time,
seed_time=None,
edge_time=None,
seed_time=None,
edge_weight=edge_weight,
replace=args.replace,
directed=args.directed,
Expand Down
82 changes: 40 additions & 42 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
#include "pyg_lib/csrc/utils/cpu/convert.h"
#include "pyg_lib/csrc/utils/types.h"

#include <iostream>

namespace pyg {
namespace sampler {

Expand Down Expand Up @@ -74,11 +72,11 @@ class NeighborSampler {
dst_mapper, generator, out_global_dst_nodes);
}

void edge_temporal_sample(const node_t global_src_node,
void node_temporal_sample(const node_t global_src_node,
const scalar_t local_src_node,
const temporal_t* edge_times,
const int64_t count,
const temporal_t seed_time,
const temporal_t* time,
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
Expand All @@ -87,28 +85,27 @@ class NeighborSampler {

if ((row_end - row_start == 0) || (count == 0))
return;

// Find new `row_end` such that all neighbors fulfill temporal constraints:
std::vector<int> v(row_end - row_start);
std::iota(std::begin(v), std::end(v), row_start);
auto it = std::upper_bound(v.begin(), v.end(), seed_time,
[&](const scalar_t& a, const scalar_t& b) {
return a < edge_times[b];
});
row_end = it - v.begin() + row_start;
auto it = std::upper_bound(
col_ + row_start, col_ + row_end, seed_time,
[&](const scalar_t& a, const scalar_t& b) { return a < time[b]; });
row_end = it - col_;

if (temporal_strategy_ == "last" && count >= 0) {
row_start = std::max(row_start, (scalar_t)(row_end - count));
}

if (row_end - row_start > 1) {
TORCH_CHECK(edge_times[row_start] <= edge_times[row_end - 1],
TORCH_CHECK(time[col_[row_start]] <= time[col_[row_end - 1]],
"Found invalid non-sorted temporal neighborhood");
}

_sample(global_src_node, local_src_node, row_start, row_end, count,
dst_mapper, generator, out_global_dst_nodes);
}

void node_temporal_sample(const node_t global_src_node,
void edge_temporal_sample(const node_t global_src_node,
const scalar_t local_src_node,
const int64_t count,
const temporal_t seed_time,
Expand All @@ -124,16 +121,15 @@ class NeighborSampler {

// Find new `row_end` such that all neighbors fulfill temporal constraints:
auto it = std::upper_bound(
col_ + row_start, col_ + row_end, seed_time,
[&](const scalar_t& a, const scalar_t& b) { return a < time[b]; });
row_end = it - col_;
time + row_start, time + row_end, seed_time,
[&](const scalar_t& a, const scalar_t& b) { return a < b; });
row_end = it - time;

if (temporal_strategy_ == "last" && count >= 0) {
row_start = std::max(row_start, (scalar_t)(row_end - count));
}

if (row_end - row_start > 1) {
TORCH_CHECK(time[col_[row_start]] <= time[col_[row_end - 1]],
TORCH_CHECK(time[row_start] <= time[row_end - 1],
"Found invalid non-sorted temporal neighborhood");
}

Expand Down Expand Up @@ -344,16 +340,17 @@ sample(const at::Tensor& rowptr,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
const bool csc,
const std::string temporal_strategy) {
TORCH_CHECK(!node_time.has_value() || disjoint,
"Temporal sampling needs to create disjoint subgraphs");

TORCH_CHECK(!edge_time.has_value() || disjoint,
"Temporal sampling needs to create disjoint subgraphs");
TORCH_CHECK(node_time.has_value() && edge_time.has_value(),
"Only one of node-level or edge-level sampling is supported ");

TORCH_CHECK(rowptr.is_contiguous(), "Non-contiguous 'rowptr'");
TORCH_CHECK(col.is_contiguous(), "Non-contiguous 'col'");
Expand All @@ -372,7 +369,6 @@ sample(const at::Tensor& rowptr,
}
TORCH_CHECK(!(node_time.has_value() && edge_weight.has_value()),
"Biased node temporal sampling not yet supported");

TORCH_CHECK(!(edge_time.has_value() && edge_weight.has_value()),
"Biased edge temporal sampling not yet supported");

Expand Down Expand Up @@ -463,9 +459,9 @@ sample(const at::Tensor& rowptr,
sampler.edge_temporal_sample(
/*global_src_node=*/sampled_nodes[i],
/*local_src_node=*/i,
/*edge_time=*/edge_time_data,
/*count=*/count,
/*seed_time=*/seed_times[batch_idx],
/*time=*/edge_time_data,
/*dst_mapper=*/mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/sampled_nodes);
Expand All @@ -478,7 +474,8 @@ sample(const at::Tensor& rowptr,
const auto batch_idx = sampled_nodes[i].first;
sampler.node_temporal_sample(
/*global_src_node=*/sampled_nodes[i],
/*local_src_node=*/i, /*count=*/count,
/*local_src_node=*/i,
/*count=*/count,
/*seed_time=*/seed_times[batch_idx],
/*time=*/time_data,
/*dst_mapper=*/mapper,
Expand Down Expand Up @@ -529,15 +526,17 @@ sample(const std::vector<node_type>& node_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
const bool csc,
const std::string temporal_strategy) {
TORCH_CHECK(!node_time_dict.has_value() || disjoint,
"Node temporal sampling needs to create disjoint subgraphs");
TORCH_CHECK(!edge_time_dict.has_value() || disjoint,
"Edge temporal sampling needs to create disjoint subgraphs");
TORCH_CHECK(node_time_dict.has_value() && edge_time_dict.has_value(),
"Only one of node-level or edge-level sampling is supported ");

for (const auto& kv : rowptr_dict) {
const at::Tensor& rowptr = kv.value();
Expand All @@ -554,13 +553,13 @@ sample(const std::vector<node_type>& node_types,
if (node_time_dict.has_value()) {
for (const auto& kv : node_time_dict.value()) {
const at::Tensor& time = kv.value();
TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'node time'");
TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'node_time'");
}
}
if (edge_time_dict.has_value()) {
for (const auto& kv : edge_time_dict.value()) {
const at::Tensor& time = kv.value();
TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'edge time'");
TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'edge_time'");
}
}
if (seed_time_dict.has_value()) {
Expand All @@ -571,6 +570,8 @@ sample(const std::vector<node_type>& node_types,
}
TORCH_CHECK(!(node_time_dict.has_value() && edge_weight_dict.has_value()),
"Biased temporal sampling not yet supported");
TORCH_CHECK(!(edge_time_dict.has_value() && edge_weight_dict.has_value()),
"Biased temporal sampling not yet supported");

c10::Dict<rel_type, at::Tensor> out_row_dict, out_col_dict;
c10::Dict<node_type, at::Tensor> out_node_id_dict;
Expand Down Expand Up @@ -686,9 +687,6 @@ sample(const std::vector<node_type>& node_types,
for (size_t i = 0; i < seed.numel(); ++i) {
seed_times.push_back(time_data[seed_data[i]]);
}
} else if (edge_time_dict.has_value()) {
// Not supported
exit(0);
}
}

Expand Down Expand Up @@ -753,25 +751,25 @@ sample(const std::vector<node_type>& node_types,
} else if constexpr (!std::is_scalar<node_t>::value) {
if (edge_time_dict.has_value() &&
edge_time_dict.value().contains(to_rel_type(k))) {
// Edge temporal sampling
const at::Tensor& edge_times =
// Edge-level temporal sampling:
const at::Tensor& edge_time =
edge_time_dict.value().at(to_rel_type(k));
const auto edge_times_data =
edge_times.data_ptr<temporal_t>();
const auto edge_time_data =
edge_time.data_ptr<temporal_t>();
for (size_t i = begin; i < end; ++i) {
const auto batch_idx = src_sampled_nodes[i].first;
sampler.edge_temporal_sample(
/*global_src_node=*/src_sampled_nodes[i],
/*local_src_node=*/i,
/*edge_times=*/edge_times_data,
/*count=*/count,
/*seed_time=*/seed_times[batch_idx],
/*time=*/edge_time_data,
/*dst_mapper=*/dst_mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/dst_sampled_nodes);
}
} else {
// Temporal sampling:
// Node-level temporal sampling:
const at::Tensor& dst_time = node_time_dict.value().at(dst);
const auto dst_time_data = dst_time.data_ptr<temporal_t>();
for (size_t i = begin; i < end; ++i) {
Expand Down Expand Up @@ -901,8 +899,8 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
Expand All @@ -912,7 +910,7 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
bool return_edge_id) {
const auto out = [&] {
DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col,
seed, num_neighbors, node_time, seed_time, edge_time,
seed, num_neighbors, node_time, edge_time, seed_time,
edge_weight, csc, temporal_strategy);
}();
return std::make_tuple(std::get<0>(out), std::get<1>(out), std::get<2>(out),
Expand All @@ -933,8 +931,8 @@ hetero_neighbor_sample_kernel(
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
Expand All @@ -944,8 +942,8 @@ hetero_neighbor_sample_kernel(
bool return_edge_id) {
DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, node_types,
edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, node_time_dict, seed_time_dict,
edge_time_dict, edge_weight_dict, csc, temporal_strategy);
num_neighbors_dict, node_time_dict, edge_time_dict,
seed_time_dict, edge_weight_dict, csc, temporal_strategy);
}

std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>>
Expand All @@ -954,8 +952,8 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& seed,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
Expand All @@ -964,7 +962,7 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
std::string temporal_strategy) {
const auto out = [&] {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, rowptr, col, seed,
{num_neighbors}, node_time, seed_time, edge_time,
{num_neighbors}, node_time, edge_time, seed_time,
edge_weight, csc, temporal_strategy);
}();
return std::make_tuple(std::get<2>(out), std::get<3>(out).value(),
Expand Down
4 changes: 2 additions & 2 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
Expand All @@ -40,8 +40,8 @@ hetero_neighbor_sample_kernel(
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
Expand Down
Loading

0 comments on commit faece9a

Please sign in to comment.