Skip to content

Commit

Permalink
Break
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Dec 27, 2024
1 parent 1565723 commit bcc43da
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 70 deletions.
38 changes: 17 additions & 21 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,6 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
//
int64_t source_group_id_bits =
SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks());
uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits);
int64_t source_block_bytes =
SwissTable::num_block_bytes_from_num_groupid_bits(source_group_id_bits);
ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0);
Expand All @@ -644,29 +643,28 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
// partition.
//
ARROW_DCHECK(num_partition_bits <= target->log_blocks());
int64_t target_max_block_id =
uint32_t target_max_block_id =
((partition_id + 1) << (target->log_blocks() - num_partition_bits)) - 1;

overflow_group_ids->clear();
overflow_hashes->clear();

// For each source block...
int64_t source_blocks = 1LL << source->log_blocks();
for (int64_t block_id = 0; block_id < source_blocks; ++block_id) {
uint32_t source_blocks = 1 << source->log_blocks();
for (uint32_t block_id = 0; block_id < source_blocks; ++block_id) {
uint8_t* block_bytes = source->blocks() + block_id * source_block_bytes;
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes);

// For each non-empty source slot...
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
constexpr int kSlotsPerBlock = 8;
int num_full_slots =
kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
int num_full_slots = SwissTable::kSlotsPerBlock -
static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
for (int local_slot_id = 0; local_slot_id < num_full_slots; ++local_slot_id) {
// Read group id and hash for this slot.
//
uint64_t group_id =
source->extract_group_id(block_bytes, local_slot_id, source_group_id_mask);
int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
uint32_t group_id =
source->extract_group_id(block_bytes, local_slot_id, source_group_id_bits);
uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
uint32_t hash = source->hashes()[global_slot_id];
// Insert partition id into the highest bits of hash, shifting the
// remaining hash bits right.
Expand All @@ -689,12 +687,12 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
}
}

inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_id,
uint32_t hash, int64_t max_block_id) {
inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint32_t group_id,
uint32_t hash, uint32_t max_block_id) {
// Load the first block to visit for this hash
//
int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks());
int64_t block_id_mask = ((1LL << target->log_blocks()) - 1);
uint32_t block_id = SwissTable::block_id_from_hash(hash, target->log_blocks());
uint32_t block_id_mask = (1 << target->log_blocks()) - 1;
int64_t num_group_id_bits =
SwissTable::num_groupid_bits_from_log_blocks(target->log_blocks());
int64_t num_block_bytes =
Expand All @@ -715,19 +713,17 @@ inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_i
if ((block & kHighBitOfEachByte) == 0) {
return false;
}
constexpr int kSlotsPerBlock = 8;
int local_slot_id =
kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
target->insert_into_empty_slot(static_cast<uint32_t>(global_slot_id), hash,
static_cast<uint32_t>(group_id));
int local_slot_id = SwissTable::kSlotsPerBlock -
static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
target->insert_into_empty_slot(global_slot_id, hash, group_id);
return true;
}

void SwissTableMerge::InsertNewGroups(SwissTable* target,
const std::vector<uint32_t>& group_ids,
const std::vector<uint32_t>& hashes) {
int64_t num_blocks = 1LL << target->log_blocks();
uint32_t num_blocks = 1 << target->log_blocks();
for (size_t i = 0; i < group_ids.size(); ++i) {
std::ignore = InsertNewGroup(target, group_ids[i], hashes[i], num_blocks);
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/acero/swiss_join_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,8 @@ class SwissTableMerge {
// Max block id value greater or equal to the number of blocks guarantees that
// the search will not be stopped.
//
static inline bool InsertNewGroup(SwissTable* target, uint64_t group_id, uint32_t hash,
int64_t max_block_id);
static inline bool InsertNewGroup(SwissTable* target, uint32_t group_id, uint32_t hash,
uint32_t max_block_id);
};

struct SwissTableWithKeys {
Expand Down
50 changes: 23 additions & 27 deletions cpp/src/arrow/compute/key_map_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,17 @@ void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selec
} else {
int64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
int64_t num_groupid_bytes = num_groupid_bits / 8;
uint32_t mask = num_groupid_bytes == 1 ? 0xFF
: num_groupid_bytes == 2 ? 0xFFFF
: 0xFFFFFFFF;
uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_groupid_bits);
int64_t num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits);
const uint8_t* slots_base = blocks_->data() + bytes_status_in_block_;

for (int i = 0; i < num_keys; ++i) {
uint32_t id = use_selection ? selection[i] : i;
uint32_t hash = hashes[id];
uint32_t block_id = hash >> (bits_hash_ - log_blocks_);
uint32_t block_id = block_id_from_hash(hash, log_blocks_);
uint32_t group_id = *reinterpret_cast<const uint32_t*>(
slots_base + block_id * num_block_bytes + local_slots[id] * num_groupid_bytes);
group_id &= mask;
group_id &= group_id_mask;
out_group_ids[id] = group_id;
}
}
Expand Down Expand Up @@ -163,9 +161,9 @@ void SwissTable::init_slot_ids(const int num_keys, const uint16_t* selection,
for (int i = 0; i < num_keys; ++i) {
uint16_t id = selection[i];
uint32_t hash = hashes[id];
uint32_t iblock = hash >> (bits_hash_ - log_blocks_);
uint32_t iblock = block_id_from_hash(hash, log_blocks_);
uint32_t match = ::arrow::bit_util::GetBit(match_bitvector, id) ? 1 : 0;
uint32_t slot_id = iblock * 8u + local_slots[id] + match;
uint32_t slot_id = global_slot_id(iblock, local_slots[id] + match);
out_slot_ids[id] = slot_id;
}
}
Expand All @@ -188,7 +186,7 @@ void SwissTable::init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* id
for (uint32_t i = 0; i < num_ids; ++i) {
int id = ids[i];
uint32_t hash = hashes[id];
uint32_t iblock = hash >> (bits_hash_ - log_blocks_);
uint32_t iblock = block_id_from_hash(hash, log_blocks_);
uint64_t block;
for (;;) {
block = *reinterpret_cast<const uint64_t*>(blocks_->mutable_data() +
Expand All @@ -200,7 +198,7 @@ void SwissTable::init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* id
iblock = (iblock + 1) & ((1 << log_blocks_) - 1);
}
uint32_t empty_slot = static_cast<int>(8 - ARROW_POPCOUNT64(block));
slot_ids[id] = iblock * 8u + empty_slot;
slot_ids[id] = global_slot_id(iblock, empty_slot);
}
}
}
Expand Down Expand Up @@ -260,8 +258,8 @@ uint64_t SwissTable::num_groups_for_resize() const {
}
}

uint64_t SwissTable::wrap_global_slot_id(uint64_t global_slot_id) const {
uint64_t global_slot_id_mask = (1 << (log_blocks_ + 3)) - 1;
uint32_t SwissTable::wrap_global_slot_id(uint32_t global_slot_id) const {
uint32_t global_slot_id_mask = static_cast<uint32_t>((1ULL << (log_blocks_ + 3)) - 1);
return global_slot_id & global_slot_id_mask;
}

Expand Down Expand Up @@ -364,18 +362,17 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl
constexpr uint64_t stamp_mask = 0x7f;
const int stamp =
static_cast<int>((hash >> bits_shift_for_block_and_stamp_) & stamp_mask);
uint64_t start_slot_id = wrap_global_slot_id(in_slot_id);
uint32_t start_slot_id = wrap_global_slot_id(in_slot_id);
int match_found;
int local_slot;
uint8_t* blockbase;
for (;;) {
blockbase = blocks_->mutable_data() + num_block_bytes * (start_slot_id >> 3);
uint64_t block = *reinterpret_cast<uint64_t*>(blockbase);

search_block<true>(block, stamp, (start_slot_id & 7), &local_slot, &match_found);
search_block<true>(block, stamp, start_slot_id & 7, &local_slot, &match_found);

start_slot_id =
wrap_global_slot_id((start_slot_id & ~7ULL) + local_slot + match_found);
start_slot_id = wrap_global_slot_id(start_slot_id & ~7U + local_slot + match_found);

// Match found can be 1 in two cases:
// - match was found
Expand All @@ -386,10 +383,8 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl
}
}

const uint64_t groupid_mask = (1ULL << num_groupid_bits) - 1;
*out_group_id =
static_cast<uint32_t>(extract_group_id(blockbase, local_slot, groupid_mask));
*out_slot_id = static_cast<uint32_t>(start_slot_id);
*out_group_id = extract_group_id(blockbase, local_slot, num_groupid_bits);
*out_slot_id = start_slot_id;

return match_found;
}
Expand Down Expand Up @@ -608,7 +603,8 @@ Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t*
for (uint32_t i = 0; i < num_ids; ++i) {
// First slot in the new starting block
const int16_t id = ids[i];
slot_ids[id] = (hashes[id] >> (bits_hash_ - log_blocks_)) * 8u;
uint32_t block_id = block_id_from_hash(hashes[id], log_blocks_);
slot_ids[id] = global_slot_id(block_id, 0);
}
}
} while (num_ids > 0);
Expand Down Expand Up @@ -656,21 +652,21 @@ Status SwissTable::grow_double() {
uint64_t block = *reinterpret_cast<const uint64_t*>(block_base);

uint32_t full_slots = CountLeadingZeros(block & kHighBitOfEachByte) >> 3;
int full_slots_new[2];
uint32_t full_slots_new[2];
full_slots_new[0] = full_slots_new[1] = 0;
util::SafeStore(double_block_base_new, kHighBitOfEachByte);
util::SafeStore(double_block_base_new + block_size_after, kHighBitOfEachByte);

for (uint32_t j = 0; j < full_slots; ++j) {
uint64_t slot_id = i * 8u + j;
uint64_t slot_id = global_slot_id(i, j);
uint32_t hash = hashes()[slot_id];
uint32_t block_id_new = hash >> (bits_hash_ - log_blocks_after);
uint32_t block_id_new = block_id_from_hash(hash, log_blocks_after);
bool is_overflow_entry = ((block_id_new >> 1) != static_cast<uint64_t>(i));
if (is_overflow_entry) {
continue;
}

int ihalf = block_id_new & 1;
uint32_t ihalf = block_id_new & 1;
uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask;
uint64_t group_id_bit_offs = j * num_group_id_bits_before;
uint64_t group_id =
Expand All @@ -679,7 +675,7 @@ Status SwissTable::grow_double() {
(group_id_bit_offs & 7)) &
group_id_mask_before;

uint64_t slot_id_new = i * 16u + ihalf * 8u + full_slots_new[ihalf];
uint64_t slot_id_new = global_slot_id(i * 2 + ihalf, full_slots_new[ihalf]);
hashes_new[slot_id_new] = hash;
uint8_t* block_base_new = double_block_base_new + ihalf * block_size_after;
block_base_new[7 - full_slots_new[ihalf]] = stamp_new;
Expand All @@ -701,9 +697,9 @@ Status SwissTable::grow_double() {
uint32_t full_slots = CountLeadingZeros(block & kHighBitOfEachByte) >> 3;

for (uint32_t j = 0; j < full_slots; ++j) {
uint64_t slot_id = i * 8u + j;
uint64_t slot_id = global_slot_id(i, j);
uint32_t hash = hashes()[slot_id];
uint32_t block_id_new = hash >> (bits_hash_ - log_blocks_after);
uint32_t block_id_new = block_id_from_hash(hash, log_blocks_after);
bool is_overflow_entry = ((block_id_new >> 1) != static_cast<uint64_t>(i));
if (!is_overflow_entry) {
continue;
Expand Down
60 changes: 40 additions & 20 deletions cpp/src/arrow/compute/key_map_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,25 @@ class ARROW_EXPORT SwissTable {
return reinterpret_cast<uint32_t*>(hashes_->mutable_data());
}

inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id);

/// \brief Extract group id for a given slot in a given block.
///
inline uint64_t extract_group_id(const uint8_t* block_ptr, int slot,
uint64_t group_id_mask) const;
inline static uint32_t extract_group_id(const uint8_t* block_ptr, int local_slot,
int64_t num_group_id_bits) {
uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_group_id_bits);
uint32_t group_id = *reinterpret_cast<const uint32_t*>(
block_ptr + bytes_status_in_block_ + local_slot * num_group_id_bits / 8);
return group_id & group_id_mask;
}

inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id);
inline static uint32_t block_id_from_hash(uint32_t hash, int log_blocks) {
return hash >> (bits_hash_ - log_blocks);
}

inline static uint32_t global_slot_id(uint32_t block_id, uint32_t local_slot_id) {
return block_id * static_cast<uint32_t>(kSlotsPerBlock) + local_slot_id;
}

static int64_t num_groupid_bits_from_log_blocks(int log_blocks) {
int required_bits = log_blocks + 3;
Expand All @@ -106,6 +119,8 @@ class ARROW_EXPORT SwissTable {
return num_groupid_bits + bytes_status_in_block_;
}

static constexpr int kSlotsPerBlock = 8;

// Use 32-bit hash for now
static constexpr int bits_hash_ = 32;

Expand Down Expand Up @@ -153,7 +168,7 @@ class ARROW_EXPORT SwissTable {

inline uint64_t num_groups_for_resize() const;

inline uint64_t wrap_global_slot_id(uint64_t global_slot_id) const;
inline uint32_t wrap_global_slot_id(uint32_t global_slot_id) const;

void init_slot_ids(const int num_keys, const uint16_t* selection,
const uint32_t* hashes, const uint8_t* local_slots,
Expand Down Expand Up @@ -205,6 +220,10 @@ class ARROW_EXPORT SwissTable {
// Resize large hash tables when 75% full.
Status grow_double();

static uint32_t group_id_mask_from_num_groupid_bits(int64_t num_groupid_bits) {
return static_cast<uint32_t>((1ULL << num_groupid_bits) - 1);
}

static constexpr int bytes_status_in_block_ = 8;

// Number of hash bits stored in slots in a block.
Expand Down Expand Up @@ -247,22 +266,23 @@ class ARROW_EXPORT SwissTable {
MemoryPool* pool_;
};

uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot,
uint64_t group_id_mask) const {
// Group id values for all 8 slots in the block are bit-packed and follow the status
// bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In
// that case we can extract group id using aligned 64-bit word access.
int num_group_id_bits = static_cast<int>(ARROW_POPCOUNT64(group_id_mask));
assert(num_group_id_bits == 8 || num_group_id_bits == 16 || num_group_id_bits == 32 ||
num_group_id_bits == 64);

int bit_offset = slot * num_group_id_bits;
const uint64_t* group_id_bytes =
reinterpret_cast<const uint64_t*>(block_ptr) + 1 + (bit_offset >> 6);
uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask;

return group_id;
}
// uint32_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot,
// uint64_t group_id_mask) const {
// // Group id values for all 8 slots in the block are bit-packed and follow the status
// // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In
// // that case we can extract group id using aligned 64-bit word access.
// int num_group_id_bits = static_cast<int>(ARROW_POPCOUNT64(group_id_mask));
// assert(num_group_id_bits == 8 || num_group_id_bits == 16 || num_group_id_bits == 32
// ||
// num_group_id_bits == 64);

// int bit_offset = slot * num_group_id_bits;
// const uint64_t* group_id_bytes =
// reinterpret_cast<const uint64_t*>(block_ptr) + 1 + (bit_offset >> 6);
// uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask;

// return group_id;
// }

void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash,
uint32_t group_id) {
Expand Down

0 comments on commit bcc43da

Please sign in to comment.