Skip to content

Commit 8784104

Browse files
Add Neon implementation of std::swap_ranges (#5819)
Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent 0609cbf commit 8784104

File tree

2 files changed

+96
-11
lines changed

2 files changed

+96
-11
lines changed

stl/inc/xutility

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ _STL_DISABLE_CLANG_WARNINGS
100100
#define _VECTORIZED_ROTATE _VECTORIZED_FOR_X64_X86
101101
#define _VECTORIZED_SEARCH _VECTORIZED_FOR_X64_X86
102102
#define _VECTORIZED_SEARCH_N _VECTORIZED_FOR_X64_X86
103-
#define _VECTORIZED_SWAP_RANGES _VECTORIZED_FOR_X64_X86
103+
#define _VECTORIZED_SWAP_RANGES _VECTORIZED_FOR_X64_X86_ARM64
104104
#define _VECTORIZED_UNIQUE _VECTORIZED_FOR_X64_X86
105105
#define _VECTORIZED_UNIQUE_COPY _VECTORIZED_FOR_X64_X86
106106

stl/src/vector_algorithms.cpp

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,29 @@
55
#error _M_CEE_PURE should not be defined when compiling vector_algorithms.cpp.
66
#endif
77

8-
#if defined(_M_IX86) || defined(_M_X64) // NB: includes _M_ARM64EC
8+
#ifndef _DEBUG
9+
#pragma optimize("t", on) // TRANSITION, GH-2108: Override /Os with /Ot for this TU before any function definitions
10+
#endif
11+
912
#include <__msvc_minmax.hpp>
1013
#include <cstdint>
1114
#include <cstring>
1215
#include <cwchar>
1316
#include <type_traits>
1417

15-
#ifndef _M_ARM64EC
18+
#if !defined(_M_ARM64) && !defined(_M_ARM64EC)
1619
#include <intrin.h>
1720
#include <isa_availability.h>
1821

1922
extern "C" long __isa_enabled;
23+
#endif // ^^^ !defined(_M_ARM64) && !defined(_M_ARM64EC) ^^^
2024

21-
#ifndef _DEBUG
22-
#pragma optimize("t", on) // Override /Os with /Ot for this TU
23-
#endif // !defined(_DEBUG)
24-
#endif // ^^^ !defined(_M_ARM64EC) ^^^
25+
#ifdef _M_ARM64
26+
#include <arm64_neon.h>
27+
#endif
2528

2629
namespace {
27-
#ifndef _M_ARM64EC
30+
#if !defined(_M_ARM64) && !defined(_M_ARM64EC)
2831
bool _Use_avx2() noexcept {
2932
return __isa_enabled & (1 << __ISA_AVAILABLE_AVX2);
3033
}
@@ -51,7 +54,7 @@ namespace {
5154
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
5255
reinterpret_cast<const unsigned char*>(_Tail_masks) + (32 - _Count_in_bytes)));
5356
}
54-
#endif // ^^^ !defined(_M_ARM64EC) ^^^
57+
#endif // ^^^ !defined(_M_ARM64) && !defined(_M_ARM64EC) ^^^
5558

5659
size_t _Byte_length(const void* const _First, const void* const _Last) noexcept {
5760
return static_cast<const unsigned char*>(_Last) - static_cast<const unsigned char*>(_First);
@@ -78,6 +81,86 @@ namespace {
7881

7982
extern "C" {
8083

84+
#ifdef _M_ARM64
85+
__declspec(noalias) void __cdecl __std_swap_ranges_trivially_swappable_noalias(
86+
void* _First1, void* const _Last1, void* _First2) noexcept {
87+
if (_Byte_length(_First1, _Last1) >= 64) {
88+
constexpr size_t _Mask_64 = ~((static_cast<size_t>(1) << 6) - 1);
89+
const void* _Stop_at = _First1;
90+
_Advance_bytes(_Stop_at, _Byte_length(_First1, _Last1) & _Mask_64);
91+
do {
92+
const uint8x16_t _Left1 = vld1q_u8(static_cast<uint8_t*>(_First1) + 0);
93+
const uint8x16_t _Left2 = vld1q_u8(static_cast<uint8_t*>(_First1) + 16);
94+
const uint8x16_t _Left3 = vld1q_u8(static_cast<uint8_t*>(_First1) + 32);
95+
const uint8x16_t _Left4 = vld1q_u8(static_cast<uint8_t*>(_First1) + 48);
96+
const uint8x16_t _Right1 = vld1q_u8(static_cast<uint8_t*>(_First2) + 0);
97+
const uint8x16_t _Right2 = vld1q_u8(static_cast<uint8_t*>(_First2) + 16);
98+
const uint8x16_t _Right3 = vld1q_u8(static_cast<uint8_t*>(_First2) + 32);
99+
const uint8x16_t _Right4 = vld1q_u8(static_cast<uint8_t*>(_First2) + 48);
100+
vst1q_u8(static_cast<uint8_t*>(_First1) + 0, _Right1);
101+
vst1q_u8(static_cast<uint8_t*>(_First1) + 16, _Right2);
102+
vst1q_u8(static_cast<uint8_t*>(_First1) + 32, _Right3);
103+
vst1q_u8(static_cast<uint8_t*>(_First1) + 48, _Right4);
104+
vst1q_u8(static_cast<uint8_t*>(_First2) + 0, _Left1);
105+
vst1q_u8(static_cast<uint8_t*>(_First2) + 16, _Left2);
106+
vst1q_u8(static_cast<uint8_t*>(_First2) + 32, _Left3);
107+
vst1q_u8(static_cast<uint8_t*>(_First2) + 48, _Left4);
108+
_Advance_bytes(_First1, 64);
109+
_Advance_bytes(_First2, 64);
110+
} while (_First1 != _Stop_at);
111+
}
112+
113+
if (_Byte_length(_First1, _Last1) >= 32) {
114+
const uint8x16_t _Left1 = vld1q_u8(static_cast<uint8_t*>(_First1) + 0);
115+
const uint8x16_t _Left2 = vld1q_u8(static_cast<uint8_t*>(_First1) + 16);
116+
const uint8x16_t _Right1 = vld1q_u8(static_cast<uint8_t*>(_First2) + 0);
117+
const uint8x16_t _Right2 = vld1q_u8(static_cast<uint8_t*>(_First2) + 16);
118+
vst1q_u8(static_cast<uint8_t*>(_First1) + 0, _Right1);
119+
vst1q_u8(static_cast<uint8_t*>(_First1) + 16, _Right2);
120+
vst1q_u8(static_cast<uint8_t*>(_First2) + 0, _Left1);
121+
vst1q_u8(static_cast<uint8_t*>(_First2) + 16, _Left2);
122+
_Advance_bytes(_First1, 32);
123+
_Advance_bytes(_First2, 32);
124+
}
125+
126+
if (_Byte_length(_First1, _Last1) >= 16) {
127+
const uint8x16_t _Left = vld1q_u8(static_cast<uint8_t*>(_First1));
128+
const uint8x16_t _Right = vld1q_u8(static_cast<uint8_t*>(_First2));
129+
vst1q_u8(static_cast<uint8_t*>(_First1), _Right);
130+
vst1q_u8(static_cast<uint8_t*>(_First2), _Left);
131+
_Advance_bytes(_First1, 16);
132+
_Advance_bytes(_First2, 16);
133+
}
134+
135+
if (_Byte_length(_First1, _Last1) >= 8) {
136+
const uint8x8_t _Left = vld1_u8(static_cast<uint8_t*>(_First1));
137+
const uint8x8_t _Right = vld1_u8(static_cast<uint8_t*>(_First2));
138+
vst1_u8(static_cast<uint8_t*>(_First1), _Right);
139+
vst1_u8(static_cast<uint8_t*>(_First2), _Left);
140+
_Advance_bytes(_First1, 8);
141+
_Advance_bytes(_First2, 8);
142+
}
143+
144+
if (_Byte_length(_First1, _Last1) >= 4) {
145+
uint32x2_t _Left = vdup_n_u32(0);
146+
uint32x2_t _Right = vdup_n_u32(0);
147+
_Left = vld1_lane_u32(static_cast<uint32_t*>(_First1), _Left, 0);
148+
_Right = vld1_lane_u32(static_cast<uint32_t*>(_First2), _Right, 0);
149+
vst1_lane_u32(static_cast<uint32_t*>(_First1), _Right, 0);
150+
vst1_lane_u32(static_cast<uint32_t*>(_First2), _Left, 0);
151+
_Advance_bytes(_First1, 4);
152+
_Advance_bytes(_First2, 4);
153+
}
154+
155+
auto _First1c = static_cast<unsigned char*>(_First1);
156+
auto _First2c = static_cast<unsigned char*>(_First2);
157+
for (; _First1c != _Last1; ++_First1c, ++_First2c) {
158+
const unsigned char _Ch = *_First1c;
159+
*_First1c = *_First2c;
160+
*_First2c = _Ch;
161+
}
162+
}
163+
#else // ^^^ defined(_M_ARM64) / !defined(_M_ARM64) vvv
81164
__declspec(noalias) void __cdecl __std_swap_ranges_trivially_swappable_noalias(
82165
void* _First1, void* const _Last1, void* _First2) noexcept {
83166
#ifndef _M_ARM64EC
@@ -157,15 +240,17 @@ __declspec(noalias) void __cdecl __std_swap_ranges_trivially_swappable_noalias(
157240
}
158241
}
159242

160-
// TRANSITION, ABI: __std_swap_ranges_trivially_swappable() is preserved for binary compatibility
243+
// TRANSITION, ABI: __std_swap_ranges_trivially_swappable() is preserved for binary compatibility (x64/x86/ARM64EC)
161244
void* __cdecl __std_swap_ranges_trivially_swappable(
162245
void* const _First1, void* const _Last1, void* const _First2) noexcept {
163246
__std_swap_ranges_trivially_swappable_noalias(_First1, _Last1, _First2);
164247
return static_cast<char*>(_First2) + (static_cast<char*>(_Last1) - static_cast<char*>(_First1));
165248
}
249+
#endif // ^^^ !defined(_M_ARM64) ^^^
166250

167251
} // extern "C"
168252

253+
#ifndef _M_ARM64
169254
namespace {
170255
namespace _Rotating {
171256
void _Swap_3_ranges(void* _First1, void* const _Last1, void* _First2, void* _First3) noexcept {
@@ -7694,4 +7779,4 @@ __declspec(noalias) bool __stdcall __std_bitset_from_string_2(void* const _Dest,
76947779
}
76957780

76967781
} // extern "C"
7697-
#endif // defined(_M_IX86) || defined(_M_X64)
7782+
#endif // ^^^ !defined(_M_ARM64) ^^^

0 commit comments

Comments
 (0)