Skip to content

Commit

Permalink
GH-44513: [C++] Fix overflow issues for large build side in swiss join (
Browse files Browse the repository at this point in the history
#45108)

### Rationale for this change

#44513 triggers two distinct overflow issues within swiss join, both happening when the build side table contains large enough number of rows or distinct keys. (Cases at this extent of hash join build side are rather rare, so we haven't seen them reported until now):

1. The first issue is, our swiss table implementation takes the higher `N` bits of 32-bit hash value as the index to a buffer storing "block"s (a block contains `8` key - in some code also referred to as "group" - ids). This `N`-bit number is further multiplied by the size of a block, which is also related to `N`. The `N` in the case of #44513 is `26` and a block takes `40` bytes. So the multiply is possible to produce a number over `1 << 31` (negative when interpreted as signed 32bit). In our AVX2 specialization of accessing the block buffer https://github.com/apache/arrow/blob/0a00e25f2f6fb927fb555b69038d0be9b9d9f265/cpp/src/arrow/compute/key_map_internal_avx2.cc#L404 , the issue like #41813 (comment) shows up. This is the actual issue that directly produced the segfault in #44513.
2. The other issue is, we take `7` bits of the 32-bit hash value after `N` as a "stamp" (to quick fail the hash comparison). But when `N` is greater than `25`, some arithmetic code like https://github.com/apache/arrow/blob/0a00e25f2f6fb927fb555b69038d0be9b9d9f265/cpp/src/arrow/compute/key_map_internal.cc#L397 (`bits_hash_` is `constexpr 32`, `log_blocks_` is `N`, `bits_stamp_` is `constexpr 7`, this is to retrieve the stamp from a hash) produces `hash >> -1` aka `hash >> 0xFFFFFFFF` aka `hash >> 31` (the heading `1`s are trimmed) then the stamp value is wrong and results in false-mismatched rows. This is the reason of my false positive run in #44513 (comment) .

### What changes are included in this PR?

For issue 1, use 64-bit index gather intrinsic to avoid the offset overflow.

For issue 2, do not right-shift the hash if `N + 7 >= 32`. This is actually allowing the bits overlapping between block id (the `N` bits) and stamp (the `7` bits). Though this may introduce more false-positive hash comparisons (thus worsen the performance), I think this is still more reasonable than brutally failing for `N > 25`. I introduce two members `bits_shift_for_block_and_stamp_` and `bits_shift_for_block_`, which are derived from `log_blocks_` - esp. set to `0` and `32 - N` when `N + 7 >= 32`, this is to avoid branching like `if (log_blocks_ + bits_stamp_ > bits_hash_)` in tight loops.

### Are these changes tested?

The fix is manually tested with the original case in my local. (I do have a concrete C++ UT to verify the fix but it requires too much resource and runs for too long time so it is impractical to run in any reasonable CI environment.)

### Are there any user-facing changes?

None.

* GitHub Issue: #44513

Lead-authored-by: Rossi Sun <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
  • Loading branch information
zanmato1984 and pitrou authored Jan 13, 2025
1 parent 1f63646 commit 32fcd18
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 32 deletions.
21 changes: 14 additions & 7 deletions cpp/src/arrow/compute/key_map_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes,
// Extract from hash: block index and stamp
//
uint32_t hash = hashes[i];
uint32_t iblock = hash >> (bits_hash_ - bits_stamp_ - log_blocks_);
uint32_t iblock = hash >> bits_shift_for_block_and_stamp_;
uint32_t stamp = iblock & stamp_mask;
iblock >>= bits_stamp_;
iblock >>= bits_shift_for_block_;

uint32_t num_block_bytes = num_groupid_bits + 8;
const uint8_t* blockbase =
Expand Down Expand Up @@ -399,7 +399,7 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl
const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
constexpr uint64_t stamp_mask = 0x7f;
const int stamp =
static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
static_cast<int>((hash >> bits_shift_for_block_and_stamp_) & stamp_mask);
uint64_t start_slot_id = wrap_global_slot_id(in_slot_id);
int match_found;
int local_slot;
Expand Down Expand Up @@ -659,6 +659,9 @@ Status SwissTable::grow_double() {
int num_group_id_bits_after = num_groupid_bits_from_log_blocks(log_blocks_ + 1);
uint64_t group_id_mask_before = ~0ULL >> (64 - num_group_id_bits_before);
int log_blocks_after = log_blocks_ + 1;
int bits_shift_for_block_and_stamp_after =
ComputeBitsShiftForBlockAndStamp(log_blocks_after);
int bits_shift_for_block_after = ComputeBitsShiftForBlock(log_blocks_after);
uint64_t block_size_before = (8 + num_group_id_bits_before);
uint64_t block_size_after = (8 + num_group_id_bits_after);
uint64_t block_size_total_after = (block_size_after << log_blocks_after) + padding_;
Expand Down Expand Up @@ -701,8 +704,7 @@ Status SwissTable::grow_double() {
}

int ihalf = block_id_new & 1;
uint8_t stamp_new =
hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;
uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask;
uint64_t group_id_bit_offs = j * num_group_id_bits_before;
uint64_t group_id =
(util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
Expand Down Expand Up @@ -744,8 +746,7 @@ Status SwissTable::grow_double() {
(util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
(group_id_bit_offs & 7)) &
group_id_mask_before;
uint8_t stamp_new =
hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;
uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask;

uint8_t* block_base_new =
blocks_new->mutable_data() + block_id_new * block_size_after;
Expand Down Expand Up @@ -773,6 +774,8 @@ Status SwissTable::grow_double() {
blocks_ = std::move(blocks_new);
hashes_ = std::move(hashes_new_buffer);
log_blocks_ = log_blocks_after;
bits_shift_for_block_and_stamp_ = bits_shift_for_block_and_stamp_after;
bits_shift_for_block_ = bits_shift_for_block_after;

return Status::OK();
}
Expand All @@ -784,6 +787,8 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, int log_blocks
log_minibatch_ = util::MiniBatch::kLogMiniBatchLength;

log_blocks_ = log_blocks;
bits_shift_for_block_and_stamp_ = ComputeBitsShiftForBlockAndStamp(log_blocks_);
bits_shift_for_block_ = ComputeBitsShiftForBlock(log_blocks_);
int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
num_inserted_ = 0;

Expand Down Expand Up @@ -820,6 +825,8 @@ void SwissTable::cleanup() {
hashes_ = nullptr;
}
log_blocks_ = 0;
bits_shift_for_block_and_stamp_ = ComputeBitsShiftForBlockAndStamp(log_blocks_);
bits_shift_for_block_ = ComputeBitsShiftForBlock(log_blocks_);
num_inserted_ = 0;
}

Expand Down
25 changes: 23 additions & 2 deletions cpp/src/arrow/compute/key_map_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,23 @@ class ARROW_EXPORT SwissTable {
// Resize large hash tables when 75% full.
Status grow_double();

// When log_blocks is greater than 25, there will be overlapping bits between block id
// and stamp within a 32-bit hash value. So we must check if this is the case when
// right shifting a hash value to retrieve block id and stamp. The following two
// functions derive the number of bits to right shift from the given log_blocks.
static int ComputeBitsShiftForBlockAndStamp(int log_blocks) {
if (ARROW_PREDICT_FALSE(log_blocks + bits_stamp_ > bits_hash_)) {
return 0;
}
return bits_hash_ - log_blocks - bits_stamp_;
}
static int ComputeBitsShiftForBlock(int log_blocks) {
if (ARROW_PREDICT_FALSE(log_blocks + bits_stamp_ > bits_hash_)) {
return bits_hash_ - log_blocks;
}
return bits_stamp_;
}

// Number of hash bits stored in slots in a block.
// The highest bits of hash determine block id.
// The next set of highest bits is a "stamp" stored in a slot in a block.
Expand All @@ -214,6 +231,11 @@ class ARROW_EXPORT SwissTable {
int log_minibatch_;
// Base 2 log of the number of blocks
int log_blocks_ = 0;
// The following two variables are derived from log_blocks_ as log_blocks_ changes, and
// used in tight loops to avoid calling the ComputeXXX functions (introducing a
// branching on whether log_blocks_ + bits_stamp_ > bits_hash_).
int bits_shift_for_block_and_stamp_ = ComputeBitsShiftForBlockAndStamp(log_blocks_);
int bits_shift_for_block_ = ComputeBitsShiftForBlock(log_blocks_);
// Number of keys inserted into hash table
uint32_t num_inserted_ = 0;

Expand Down Expand Up @@ -271,8 +293,7 @@ void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash,
constexpr uint64_t stamp_mask = 0x7f;

int start_slot = (slot_id & 7);
int stamp =
static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
int stamp = static_cast<int>((hash >> bits_shift_for_block_and_stamp_) & stamp_mask);
uint64_t block_id = slot_id >> 3;
uint8_t* blockbase = blocks_->mutable_data() + num_block_bytes * block_id;

Expand Down
55 changes: 32 additions & 23 deletions cpp/src/arrow/compute/key_map_internal_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,9 @@ int SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* h
// Calculate block index and hash stamp for a byte in a block
//
__m256i vhash = _mm256_loadu_si256(vhash_ptr + i);
__m256i vblock_id = _mm256_srlv_epi32(
vhash, _mm256_set1_epi32(bits_hash_ - bits_stamp_ - log_blocks_));
__m256i vblock_id = _mm256_srli_epi32(vhash, bits_shift_for_block_and_stamp_);
__m256i vstamp = _mm256_and_si256(vblock_id, vstamp_mask);
vblock_id = _mm256_srli_epi32(vblock_id, bits_stamp_);
vblock_id = _mm256_srli_epi32(vblock_id, bits_shift_for_block_);

// We now split inputs and process 4 at a time,
// in order to process 64-bit blocks
Expand Down Expand Up @@ -301,19 +300,15 @@ int SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t*
_mm256_and_si256(vhash2, _mm256_set1_epi32(0xffff0000)));
vhash1 = _mm256_or_si256(_mm256_srli_epi32(vhash1, 16),
_mm256_and_si256(vhash3, _mm256_set1_epi32(0xffff0000)));
__m256i vstamp_A = _mm256_and_si256(
_mm256_srlv_epi32(vhash0, _mm256_set1_epi32(16 - log_blocks_ - 7)),
_mm256_set1_epi16(0x7f));
__m256i vstamp_B = _mm256_and_si256(
_mm256_srlv_epi32(vhash1, _mm256_set1_epi32(16 - log_blocks_ - 7)),
_mm256_set1_epi16(0x7f));
__m256i vstamp_A = _mm256_and_si256(_mm256_srli_epi32(vhash0, 16 - log_blocks_ - 7),
_mm256_set1_epi16(0x7f));
__m256i vstamp_B = _mm256_and_si256(_mm256_srli_epi32(vhash1, 16 - log_blocks_ - 7),
_mm256_set1_epi16(0x7f));
__m256i vstamp = _mm256_or_si256(vstamp_A, _mm256_slli_epi16(vstamp_B, 8));
__m256i vblock_id_A =
_mm256_and_si256(_mm256_srlv_epi32(vhash0, _mm256_set1_epi32(16 - log_blocks_)),
_mm256_set1_epi16(block_id_mask));
__m256i vblock_id_B =
_mm256_and_si256(_mm256_srlv_epi32(vhash1, _mm256_set1_epi32(16 - log_blocks_)),
_mm256_set1_epi16(block_id_mask));
__m256i vblock_id_A = _mm256_and_si256(_mm256_srli_epi32(vhash0, 16 - log_blocks_),
_mm256_set1_epi16(block_id_mask));
__m256i vblock_id_B = _mm256_and_si256(_mm256_srli_epi32(vhash1, 16 - log_blocks_),
_mm256_set1_epi16(block_id_mask));
__m256i vblock_id = _mm256_or_si256(vblock_id_A, _mm256_slli_epi16(vblock_id_B, 8));

// Visit all block bytes in reverse order (overwriting data on multiple matches)
Expand Down Expand Up @@ -392,16 +387,30 @@ int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashe
} else {
for (int i = 0; i < num_keys / unroll; ++i) {
__m256i hash = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + i);
// Extend hash and local_slot to 64-bit to compute 64-bit group id offsets to
// gather from. This is to prevent index overflow issues in GH-44513.
// NB: Use zero-extend conversion for unsigned hash.
__m256i hash_lo = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(hash));
__m256i hash_hi = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(hash, 1));
__m256i local_slot =
_mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(local_slots)[i]);
local_slot = _mm256_shuffle_epi8(
local_slot, _mm256_setr_epi32(0x80808000, 0x80808001, 0x80808002, 0x80808003,
0x80808004, 0x80808005, 0x80808006, 0x80808007));
local_slot = _mm256_mullo_epi32(local_slot, _mm256_set1_epi32(byte_size));
__m256i pos = _mm256_srlv_epi32(hash, _mm256_set1_epi32(bits_hash_ - log_blocks_));
pos = _mm256_mullo_epi32(pos, _mm256_set1_epi32(byte_multiplier));
pos = _mm256_add_epi32(pos, local_slot);
__m256i group_id = _mm256_i32gather_epi32(elements, pos, 1);
__m256i local_slot_lo = _mm256_shuffle_epi8(
local_slot, _mm256_setr_epi32(0x80808000, 0x80808080, 0x80808001, 0x80808080,
0x80808002, 0x80808080, 0x80808003, 0x80808080));
__m256i local_slot_hi = _mm256_shuffle_epi8(
local_slot, _mm256_setr_epi32(0x80808004, 0x80808080, 0x80808005, 0x80808080,
0x80808006, 0x80808080, 0x80808007, 0x80808080));
local_slot_lo = _mm256_mul_epu32(local_slot_lo, _mm256_set1_epi32(byte_size));
local_slot_hi = _mm256_mul_epu32(local_slot_hi, _mm256_set1_epi32(byte_size));
__m256i pos_lo = _mm256_srli_epi64(hash_lo, bits_hash_ - log_blocks_);
__m256i pos_hi = _mm256_srli_epi64(hash_hi, bits_hash_ - log_blocks_);
pos_lo = _mm256_mul_epu32(pos_lo, _mm256_set1_epi32(byte_multiplier));
pos_hi = _mm256_mul_epu32(pos_hi, _mm256_set1_epi32(byte_multiplier));
pos_lo = _mm256_add_epi64(pos_lo, local_slot_lo);
pos_hi = _mm256_add_epi64(pos_hi, local_slot_hi);
__m128i group_id_lo = _mm256_i64gather_epi32(elements, pos_lo, 1);
__m128i group_id_hi = _mm256_i64gather_epi32(elements, pos_hi, 1);
__m256i group_id = _mm256_set_m128i(group_id_hi, group_id_lo);
group_id = _mm256_and_si256(group_id, _mm256_set1_epi32(mask));
_mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id);
}
Expand Down

0 comments on commit 32fcd18

Please sign in to comment.