GH-44513: [C++] Fix overflow issues for large build side in swiss join #45108
Conversation
Thanks for opening a pull request! If this is not a minor PR, could you open an issue for this pull request on GitHub? https://github.com/apache/arrow/issues/new/choose Opening GitHub issues ahead of time contributes to the Openness of the Apache Arrow project. Then could you also rename the pull request title in the following format?
Hi @pitrou, would you help take a look? Thanks.
```cpp
// number of bits to right shift, rather than branching on whether log_blocks_ > 25
// every time in tight loops.
int bits_shift_for_block_and_stamp_ = bits_hash_ - log_blocks_ - bits_stamp_;
int bits_shift_for_block_ = bits_stamp_;
```
Since the computation is repeated several times, perhaps we can have a short helper function to factor it out? Something like:
```cpp
static std::pair<int, int> ComputeBitShifts(int log_blocks) {
  if (log_blocks + bits_stamp_ > bits_hash_) {
    return {0, bits_hash_ - log_blocks};
  } else {
    return {bits_hash_ - log_blocks - bits_stamp_, bits_stamp_};
  }
}
```
Right. Done.
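For readers skimming the thread, here is a minimal self-contained sketch of how the suggested helper can feed the two new members. The `SwissTableSketch` wrapper and `SetLogBlocks` name are made up for illustration; the committed code lives inside the real `SwissTable` class and may differ in detail.

```cpp
#include <tuple>
#include <utility>

// Constants mirroring the members referenced in the diff above.
constexpr int bits_hash_ = 32;
constexpr int bits_stamp_ = 7;

static std::pair<int, int> ComputeBitShifts(int log_blocks) {
  if (log_blocks + bits_stamp_ > bits_hash_) {
    // Stamp and block id overlap (log_blocks > 25): don't shift the hash when
    // taking the stamp, and shift by the remaining width to get the block id.
    return {0, bits_hash_ - log_blocks};
  }
  return {bits_hash_ - log_blocks - bits_stamp_, bits_stamp_};
}

struct SwissTableSketch {
  int log_blocks_ = 0;
  int bits_shift_for_block_and_stamp_ = bits_hash_ - bits_stamp_;
  int bits_shift_for_block_ = bits_stamp_;

  // Refresh both precomputed shifts whenever log_blocks_ changes, so tight
  // loops never branch on log_blocks_ > 25.
  void SetLogBlocks(int log_blocks) {
    log_blocks_ = log_blocks;
    std::tie(bits_shift_for_block_and_stamp_, bits_shift_for_block_) =
        ComputeBitShifts(log_blocks_);
  }
};
```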
```cpp
// This is to prevent index overflow issues in GH-44513.
// NB: Use zero-extend conversion for unsigned hash.
__m256i hash_lo = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(hash));
__m256i hash_hi = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(hash, 1));
__m256i local_slot =
    _mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(local_slots)[i]);
local_slot = _mm256_shuffle_epi8(
```
Hmm... so this first uses `_mm256_shuffle_epi8` to expand from 8-bit to 32-bit lanes, and then `_mm256_cvtepi32_epi64` below expands it from 32-bit to 64-bit lanes? Would it be quicker to shuffle directly from 8-bit to 64-bit (twice, I suppose)?

(Interestingly, `_mm256_shuffle_epi8` is faster than `_mm256_cvtepi32_epi64` according to https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_shuffle_epi8&ig_expand=1798,6006,1628,6006)
I was thinking that we could save one multiply of `local_offset * byte_size`. But yeah, once we have shuffled to 64-bit lanes, we can use `_mm256_mul_epi32` (5 cycles) to replace `_mm256_mullo_epi32` (10 cycles). Then we have 2 `_mm256_shuffle_epi8`s (1 cycle each) + 2 `_mm256_mul_epi32`s = 12 cycles in total, vs. 1 `_mm256_shuffle_epi8` + 1 `_mm256_mullo_epi32` + 2 `_mm256_cvtepi32_epi64`s (3 cycles each) = 17 cycles in total, which is still a win.

I've updated. Thank you for this.
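For anyone following along, here is a hedged, self-contained sketch of the byte-to-64-bit-lane expansion being discussed. The mask constants, function name, and `local_slots` pointer are illustrative, not the exact code in the PR.

```cpp
#include <immintrin.h>
#include <cstdint>

// Broadcast 8 one-byte local slots into every 64-bit lane, then use two
// per-lane byte shuffles to zero-extend them straight into 64-bit lanes.
inline void ExpandBytesTo64BitLanes(const uint8_t* local_slots, __m256i* lo,
                                    __m256i* hi) {
  // Bytes s0..s7 replicated into both 128-bit halves of the register.
  __m256i bytes =
      _mm256_set1_epi64x(*reinterpret_cast<const uint64_t*>(local_slots));
  // _mm256_shuffle_epi8 shuffles each 128-bit half independently; a mask byte
  // of -1 writes zero, which gives the zero-extension for free.
  const __m256i kLoMask = _mm256_setr_epi8(
      0, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1,
      2, -1, -1, -1, -1, -1, -1, -1, 3, -1, -1, -1, -1, -1, -1, -1);
  const __m256i kHiMask = _mm256_setr_epi8(
      4, -1, -1, -1, -1, -1, -1, -1, 5, -1, -1, -1, -1, -1, -1, -1,
      6, -1, -1, -1, -1, -1, -1, -1, 7, -1, -1, -1, -1, -1, -1, -1);
  *lo = _mm256_shuffle_epi8(bytes, kLoMask);  // {s0, s1, s2, s3} as uint64
  *hi = _mm256_shuffle_epi8(bytes, kHiMask);  // {s4, s5, s6, s7} as uint64
}
```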
```cpp
__m256i local_slot_hi =
    _mm256_cvtepi32_epi64(_mm256_extracti128_si256(local_slot, 1));
__m256i pos_lo =
    _mm256_srlv_epi64(hash_lo, _mm256_set1_epi64x(bits_hash_ - log_blocks_));
```
By the way, why not `_mm256_srli_epi64(hash_lo, bits_hash_ - log_blocks_)`?
Just copied from the original code, plus I wasn't aware of `_mm256_srli_epi64` then - still learning :)

Updated here and in a couple of other places that had unnecessary vector shifts. Thank you!
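For readers comparing the two intrinsics, a small sketch of the equivalence being relied on here. Values are made up, and it assumes a compiler that accepts a runtime count for `_mm256_srli_epi64` (GCC and Clang fall back to the register-count form in that case).

```cpp
#include <immintrin.h>
#include <cstdint>

int main() {
  __m256i hash_lo = _mm256_set1_epi64x(0x00000000DEADBEEFULL);
  int shift = 32 - 26;  // bits_hash_ - log_blocks_ in the PR's terminology

  // Per-lane variable counts: only needed when lanes shift by different amounts.
  __m256i a = _mm256_srlv_epi64(hash_lo, _mm256_set1_epi64x(shift));
  // Single count applied to all four lanes: sufficient here and cheaper to set up.
  __m256i b = _mm256_srli_epi64(hash_lo, shift);

  (void)a;
  (void)b;
  return 0;
}
```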
```cpp
pos_lo = _mm256_mul_epi32(pos_lo, _mm256_set1_epi32(byte_multiplier));
pos_hi = _mm256_mul_epi32(pos_hi, _mm256_set1_epi32(byte_multiplier));
```
For the record, why are we multiplying in the signed domain rather than unsigned?
Yeah, we should use unsigned multiply. But actually they are the same in this specific case (i.e., both operands are less than `0x80000000` - note that `log_blocks_` is strictly less than `32`). Even if the result is larger than `uint32_max`, `_mm256_mul_epi32` won't do sign-extension.

Anyway, I'll update. Thank you.
Done.
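To make the signed-vs-unsigned point above concrete, a small hedged sketch; the operand values are illustrative (a 26-bit block id times 40 bytes per block, as in the rationale below).

```cpp
#include <immintrin.h>
#include <cstdint>

int main() {
  // Both intrinsics multiply the low 32 bits of each 64-bit lane into a 64-bit
  // product; _epi32 sign-extends those 32 bits, _epu32 zero-extends them. With
  // both operands below 0x80000000 the two results are identical, even though
  // the product here (0x9FFFFFD8) no longer fits in a signed 32-bit integer.
  __m256i block_id = _mm256_set1_epi64x(0x03FFFFFF);  // largest 26-bit block id
  __m256i bytes_per_block = _mm256_set1_epi64x(40);
  __m256i signed_prod = _mm256_mul_epi32(block_id, bytes_per_block);
  __m256i unsigned_prod = _mm256_mul_epu32(block_id, bytes_per_block);
  (void)signed_prod;
  (void)unsigned_prod;
  return 0;
}
```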
@ursabot please benchmark
Commit 4462ceb already has scheduled benchmark runs.
Thanks for your patience. Conbench analyzed the 3 benchmarking runs that have been run so far on PR commit 4462ceb. There were 29 benchmark results with an error.

There weren't enough matching historic benchmark results to make a call on whether there were regressions. The full Conbench report has more details.
Rationale for this change
#44513 triggers two distinct overflow issues within swiss join, both happening when the build side table contains a large enough number of rows or distinct keys. (Cases of this size on the hash join build side are rather rare, so we haven't seen them reported until now.)

1. The first `N` bits of the 32-bit hash value are used as the index into a buffer storing "block"s (a block contains `8` key ids - in some code also referred to as "group" ids). This `N`-bit number is further multiplied by the size of a block, which is also related to `N`. The `N` in the case of #44513 is `26` and a block takes `40` bytes, so the multiplication can produce a number over `1 << 31` (negative when interpreted as signed 32-bit) in our AVX2 specialization of accessing the block buffer (`arrow/cpp/src/arrow/compute/key_map_internal_avx2.cc`, line 404 at `0a00e25`).

2. The `7` bits of the 32-bit hash value after the first `N` bits are used as a "stamp" (to quickly fail the hash comparison). But when `N` is greater than `25`, arithmetic like that at `arrow/cpp/src/arrow/compute/key_map_internal.cc`, line 397 at `0a00e25` (where `bits_hash_` is `constexpr 32`, `log_blocks_` is `N`, and `bits_stamp_` is `constexpr 7`; this is the code that retrieves the stamp from a hash) produces `hash >> -1`, aka `hash >> 0xFFFFFFFF`, aka `hash >> 31` (the leading `1`s are trimmed). The stamp value is then wrong and results in falsely mismatched rows. This is the reason for my false positive run in #44513 (comment).
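To make the shift arithmetic in item 2 concrete, here is a scalar sketch of the failure mode. The values and types are illustrative; the real code is vectorized and lives in `key_map_internal.cc`.

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr uint32_t bits_hash = 32;
  constexpr uint32_t bits_stamp = 7;
  uint32_t log_blocks = 26;  // the N from GH-44513
  uint32_t hash = 0xDEADBEEFu;

  // Intended: right-shift so that the 7 stamp bits following the first N bits
  // land at the bottom. With N = 26 the count is 32 - 26 - 7 = -1, which wraps
  // to 0xFFFFFFFF as an unsigned count; hardware masks the count to its low 5
  // bits, so the shift behaves like hash >> 31 and the "stamp" collapses to
  // the single top bit of the hash.
  uint32_t count = bits_hash - log_blocks - bits_stamp;  // wraps to 0xFFFFFFFF
  uint32_t stamp = hash >> (count & 31u);                // effectively hash >> 31
  std::cout << std::hex << stamp << "\n";                // prints 1, not a 7-bit stamp
  return 0;
}
```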
What changes are included in this PR?

For issue 1, use a 64-bit index gather intrinsic to avoid the offset overflow (a hedged sketch follows).
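A minimal illustration of the idea, assuming hypothetical names (`blocks`, `byte_offsets`); it is not the exact code in `key_map_internal_avx2.cc`.

```cpp
#include <immintrin.h>

// Gather four 32-bit values using 64-bit byte offsets, so that
// block_id * bytes_per_block can exceed 1 << 31 without wrapping negative.
// The 32-bit-index variant (_mm256_i32gather_epi32) would misread memory once
// an offset reaches 1 << 31, because its offsets are interpreted as signed.
inline __m128i GatherWith64BitOffsets(const int* blocks, __m256i byte_offsets) {
  return _mm256_i64gather_epi32(blocks, byte_offsets, 1);
}
```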
For issue 2, do not right-shift the hash if `N + 7 >= 32`. This actually allows the bits to overlap between the block id (the `N` bits) and the stamp (the `7` bits). Though this may introduce more false-positive hash comparisons (and thus worsen performance), I think it is still more reasonable than brutally failing for `N > 25`. I introduce two members, `bits_shift_for_block_and_stamp_` and `bits_shift_for_block_`, which are derived from `log_blocks_` - in particular, set to `0` and `32 - N` when `N + 7 >= 32` - to avoid branching like `if (log_blocks_ + bits_stamp_ > bits_hash_)` in tight loops.

Are these changes tested?
The fix is manually tested with the original case on my local machine. (I do have a concrete C++ unit test to verify the fix, but it requires too many resources and runs for too long, so it is impractical to run in any reasonable CI environment.)
Are there any user-facing changes?
None.