Skip to content

Commit

Permalink
Increment subgraph id globally for all seed nodes (#304)
Browse files Browse the repository at this point in the history
- All seed nodes subgraph ids have their own unique subgraph id
regardless of type.
- Additionally, the method of retrieving information about the number of
nodes of a given type has been changed.
  • Loading branch information
kgajdamo authored Feb 23, 2024
1 parent 00e9def commit b242f55
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,10 @@ relabel(
threads_edge_types.push_back({edge_types});
}

int64_t N = 0;
for (const auto& kv : num_nodes_dict) {
N += kv.value() > 0 ? kv.value() : 0;
}

for (const auto& k : node_types) {
sampled_nodes_data_dict.insert(
{k, sampled_nodes_with_duplicates_dict.at(k).data_ptr<scalar_t>()});
int64_t N = num_nodes_dict.at(k);
mapper_dict.insert({k, Mapper<node_t, scalar_t>(N)});
slice_dict[k] = {0, 0};
srcs_offset_dict[k] = 0;
Expand All @@ -182,6 +178,7 @@ relabel(
{k, batch_dict.value().at(k).data_ptr<scalar_t>()});
}
}
scalar_t batch_idx = 0;
for (const auto& kv : seed_dict) {
const at::Tensor& seed = kv.value();
if constexpr (!disjoint) {
Expand All @@ -190,7 +187,8 @@ relabel(
auto& mapper = mapper_dict.at(kv.key());
const auto seed_data = seed.data_ptr<scalar_t>();
for (size_t i = 0; i < seed.numel(); ++i) {
mapper.insert({i, seed_data[i]});
mapper.insert({batch_idx, seed_data[i]});
batch_idx++;
}
}
}
Expand Down

0 comments on commit b242f55

Please sign in to comment.