From b242f55d3d8d7ea2a0bce05fa0e43e77805ca586 Mon Sep 17 00:00:00 2001 From: Kinga Gajdamowicz Date: Fri, 23 Feb 2024 11:43:45 +0100 Subject: [PATCH] Increment subgraph id globally for all seed nodes (#304) - All seed nodes subgraph ids have their own unique subgraph id regardless of type. - Additionally, the method of retrieving information about the number of nodes of a given type has been changed. --- pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp index 29ebfecf0..805dfef87 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp @@ -166,14 +166,10 @@ relabel( threads_edge_types.push_back({edge_types}); } - int64_t N = 0; - for (const auto& kv : num_nodes_dict) { - N += kv.value() > 0 ? kv.value() : 0; - } - for (const auto& k : node_types) { sampled_nodes_data_dict.insert( {k, sampled_nodes_with_duplicates_dict.at(k).data_ptr()}); + int64_t N = num_nodes_dict.at(k); mapper_dict.insert({k, Mapper(N)}); slice_dict[k] = {0, 0}; srcs_offset_dict[k] = 0; @@ -182,6 +178,7 @@ relabel( {k, batch_dict.value().at(k).data_ptr()}); } } + scalar_t batch_idx = 0; for (const auto& kv : seed_dict) { const at::Tensor& seed = kv.value(); if constexpr (!disjoint) { @@ -190,7 +187,8 @@ relabel( auto& mapper = mapper_dict.at(kv.key()); const auto seed_data = seed.data_ptr(); for (size_t i = 0; i < seed.numel(); ++i) { - mapper.insert({i, seed_data[i]}); + mapper.insert({batch_idx, seed_data[i]}); + batch_idx++; } } }