Skip to content

Commit

Permalink
fix bug to add large tensor support (#308)
Browse files Browse the repository at this point in the history
Signed-off-by: kaixuan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
kaixuanliu and pre-commit-ci[bot] authored Mar 14, 2024
1 parent 0ae8cdf commit 0228406
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions pyg_lib/csrc/ops/cpu/radix_sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,16 @@ void radix_sort_kernel(K* input_keys,
V* input_values,
K* output_keys,
V* output_values,
int elements_count,
int* histogram,
int* histogram_ps,
int64_t elements_count,
int64_t* histogram,
int64_t* histogram_ps,
int pass) {
int tid = omp_get_thread_num();
int nthreads = omp_get_num_threads();
int elements_count_4 = elements_count / 4 * 4;
int64_t elements_count_4 = elements_count / 4 * 4;

int* local_histogram = &histogram[RDX_HIST_SIZE * tid];
int* local_histogram_ps = &histogram_ps[RDX_HIST_SIZE * tid];
int64_t* local_histogram = &histogram[RDX_HIST_SIZE * tid];
int64_t* local_histogram_ps = &histogram_ps[RDX_HIST_SIZE * tid];

// Step 1: compute histogram
for (int i = 0; i < RDX_HIST_SIZE; i++) {
Expand Down Expand Up @@ -97,7 +97,7 @@ void radix_sort_kernel(K* input_keys,
#pragma omp barrier
// Step 2: prefix sum
if (tid == 0) {
int sum = 0, prev_sum = 0;
int64_t sum = 0, prev_sum = 0;
for (int bins = 0; bins < RDX_HIST_SIZE; bins++) {
for (int t = 0; t < nthreads; t++) {
sum += histogram[t * RDX_HIST_SIZE + bins];
Expand All @@ -123,7 +123,7 @@ void radix_sort_kernel(K* input_keys,
int bin_3 = (key_3 >> (pass * 8)) & 0xFF;
int bin_4 = (key_4 >> (pass * 8)) & 0xFF;

int pos;
int64_t pos;
pos = local_histogram_ps[bin_1]++;
output_keys[pos] = key_1;
output_values[pos] = input_values[i];
Expand All @@ -140,7 +140,7 @@ void radix_sort_kernel(K* input_keys,
if (tid == (nthreads - 1)) {
for (int64_t i = elements_count_4; i < elements_count; ++i) {
K key = input_keys[i];
int pos = local_histogram_ps[(key >> (pass * 8)) & 0xFF]++;
int64_t pos = local_histogram_ps[(key >> (pass * 8)) & 0xFF]++;
output_keys[pos] = key;
output_values[pos] = input_values[i];
}
Expand All @@ -161,11 +161,12 @@ std::pair<K*, V*> radix_sort_parallel(K* inp_key_buf,
int64_t elements_count,
int64_t max_value) {
int maxthreads = omp_get_max_threads();
std::unique_ptr<int[]> histogram_tmp(new int[RDX_HIST_SIZE * maxthreads]);
std::unique_ptr<int[]> histogram_ps_tmp(
new int[RDX_HIST_SIZE * maxthreads + 1]);
int* histogram = histogram_tmp.get();
int* histogram_ps = histogram_ps_tmp.get();
std::unique_ptr<int64_t[]> histogram_tmp(
new int64_t[RDX_HIST_SIZE * maxthreads]);
std::unique_ptr<int64_t[]> histogram_ps_tmp(
new int64_t[RDX_HIST_SIZE * maxthreads + 1]);
int64_t* histogram = histogram_tmp.get();
int64_t* histogram_ps = histogram_ps_tmp.get();
if (max_value == 0) {
return std::make_pair(inp_key_buf, inp_value_buf);
}
Expand Down

0 comments on commit 0228406

Please sign in to comment.