From 3da78ec73d905e8650998fcc70bebfc1ea84cbc8 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Fri, 27 Dec 2024 02:35:46 +0800 Subject: [PATCH] Fix --- cpp/src/arrow/compute/key_map_internal_avx2.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/key_map_internal_avx2.cc b/cpp/src/arrow/compute/key_map_internal_avx2.cc index a5eae01c044b7..7ed46dbb05731 100644 --- a/cpp/src/arrow/compute/key_map_internal_avx2.cc +++ b/cpp/src/arrow/compute/key_map_internal_avx2.cc @@ -390,6 +390,9 @@ int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashe } else { int64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); int64_t num_groupid_bytes = num_groupid_bits / 8; + uint32_t mask = num_groupid_bytes == 1 ? 0xFF + : num_groupid_bytes == 2 ? 0xFFFF + : 0xFFFFFFFF; int64_t num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); const int* slots_base = reinterpret_cast<const int*>(blocks_->data() + bytes_status_); @@ -412,16 +415,17 @@ int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashe __m256i slot_offset_lo = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(block_id)); __m256i slot_offset_hi = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(block_id, 1)); - slot_offset_lo = _mm256_mul_epi32( - slot_offset_lo, _mm256_set1_epi32(static_cast<int>(num_block_bytes))); - slot_offset_hi = _mm256_mul_epi32( - slot_offset_hi, _mm256_set1_epi32(static_cast<int>(num_block_bytes))); + slot_offset_lo = + _mm256_mul_epi32(slot_offset_lo, _mm256_set1_epi64x(num_block_bytes)); + slot_offset_hi = + _mm256_mul_epi32(slot_offset_hi, _mm256_set1_epi64x(num_block_bytes)); slot_offset_lo = _mm256_add_epi64(slot_offset_lo, local_slot_lo); slot_offset_hi = _mm256_add_epi64(slot_offset_hi, local_slot_hi); __m128i group_id_lo = _mm256_i64gather_epi32(slots_base, slot_offset_lo, 1); __m128i group_id_hi = _mm256_i64gather_epi32(slots_base, slot_offset_hi, 1); __m256i group_id = 
_mm256_set_m128i(group_id_hi, group_id_lo); + group_id = _mm256_and_si256(group_id, _mm256_set1_epi32(mask)); _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id); }