Commit ccd75ec

Added x86 SIMD optimizations to crypto datatypes.
- The v128 operations are optimized for SSE2/SSSE3.
- srtp_octet_string_is_eq is optimized for SSE2. When SSE2 is not
  available, a pair of 32-bit accumulators speeds up the bulk of the
  operation. Two accumulators are used to exploit the instruction-level
  parallelism available on most modern CPUs.
- In srtp_cleanse, use memset and ensure it is not optimized away by
  following it with a dummy asm statement that may potentially consume
  the contents of the memory.
- Endian conversion functions use gcc-style intrinsics when possible.

The SIMD code uses intrinsics, which are available on all modern
compilers. For MSVC, config_in_cmake.h is modified to define
gcc/clang-style SSE macros based on MSVC's predefined macros: all SSE
versions are enabled when AVX is enabled, and SSE2 is always enabled
for x86-64, or for x86 when SSE2 FP math is enabled.
1 parent 19e6a05 · commit ccd75ec
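A note on the srtp_cleanse point above: the empty asm statement names the buffer pointer as an input and clobbers "memory", so the compiler must assume the zeroed bytes may later be read and cannot delete the memset as a dead store. A minimal standalone sketch of the pattern (not the committed code, which appears in crypto/math/datatypes.c below):

    #include <string.h>

    /* Zero a buffer and keep the compiler from eliding the memset: the
     * empty asm takes the pointer as an input and declares a "memory"
     * clobber, so the zeroed bytes must be treated as observable. */
    static void cleanse_sketch(void *s, size_t len)
    {
        memset(s, 0, len);
        __asm__ __volatile__("" : : "r"(s) : "memory");
    }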

File tree

3 files changed: +249 −8 lines changed


config_in_cmake.h

Lines changed: 11 additions & 0 deletions
@@ -119,3 +119,14 @@
 #define inline
 #endif
 #endif
+
+/* Define gcc/clang-style SSE macros on compilers that don't define them (primarily, MSVC). */
+#if !defined(__SSE2__) && (defined(_M_X64) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2))
+#define __SSE2__
+#endif
+#if !defined(__SSSE3__) && defined(__AVX__)
+#define __SSSE3__
+#endif
+#if !defined(__SSE4_1__) && defined(__AVX__)
+#define __SSE4_1__
+#endif
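For reference, the MSVC predefined macros this mapping relies on: _M_X64 is defined for x86-64 targets, _M_IX86_FP is 2 when /arch:SSE2 or higher is in effect on x86, and __AVX__ is defined under /arch:AVX. An illustrative compile-time sanity check (the "config.h" include is hypothetical and stands in for the header generated from config_in_cmake.h; this is not part of the commit):

    #include "config.h"

    /* Fails the build if the mapping did not translate MSVC's /arch
     * settings into the gcc/clang-style macros. */
    #if defined(_M_X64) && !defined(__SSE2__)
    #error "expected __SSE2__ on x86-64"
    #endif
    #if defined(__AVX__) && (!defined(__SSSE3__) || !defined(__SSE4_1__))
    #error "expected __SSSE3__ and __SSE4_1__ under AVX"
    #endif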

crypto/include/datatypes.h

Lines changed: 31 additions & 5 deletions
@@ -62,6 +62,10 @@
 #error "Platform not recognized"
 #endif
 
+#if defined(__SSE2__)
+#include <emmintrin.h>
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -90,6 +94,26 @@ void v128_left_shift(v128_t *x, int shift_index);
  *
  */
 
+#if defined(__SSE2__)
+
+#define v128_set_to_zero(x) \
+    (_mm_storeu_si128((__m128i *)(x), _mm_setzero_si128()))
+
+#define v128_copy(x, y) \
+    (_mm_storeu_si128((__m128i *)(x), _mm_loadu_si128((const __m128i *)(y))))
+
+#define v128_xor(z, x, y) \
+    (_mm_storeu_si128((__m128i *)(z), \
+                      _mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \
+                                    _mm_loadu_si128((const __m128i *)(y)))))
+
+#define v128_xor_eq(z, x) \
+    (_mm_storeu_si128((__m128i *)(z), \
+                      _mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \
+                                    _mm_loadu_si128((const __m128i *)(z)))))
+
+#else /* defined(__SSE2__) */
+
 #define v128_set_to_zero(x) \
     ((x)->v32[0] = 0, (x)->v32[1] = 0, (x)->v32[2] = 0, (x)->v32[3] = 0)
 
@@ -113,6 +137,8 @@ void v128_left_shift(v128_t *x, int shift_index);
     ((z)->v64[0] ^= (x)->v64[0], (z)->v64[1] ^= (x)->v64[1])
 #endif
 
+#endif /* defined(__SSE2__) */
+
 /* NOTE! This assumes an odd ordering! */
 /* This will not be compatible directly with math on some processors */
 /* bit 0 is first 32-bit word, low order bit. in little-endian, that's
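A note on the SSE2 v128 macros added above: the unaligned _mm_loadu_si128/_mm_storeu_si128 pairs place no alignment requirement on v128_t. A standalone sketch showing the SSE2 v128_xor form in use (the union here is a reduced stand-in for libsrtp's v128_t, for illustration only):

    #include <stdint.h>
    #include <string.h>
    #include <emmintrin.h>

    /* Reduced stand-in for libsrtp's v128_t. */
    typedef union {
        uint8_t v8[16];
        uint32_t v32[4];
    } v128_t;

    #define v128_xor(z, x, y) \
        (_mm_storeu_si128((__m128i *)(z), \
                          _mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \
                                        _mm_loadu_si128((const __m128i *)(y)))))

    int main(void)
    {
        v128_t a, b, z;
        memset(a.v8, 0xaa, sizeof(a.v8));
        memset(b.v8, 0xff, sizeof(b.v8));
        v128_xor(&z, &a, &b);
        return z.v32[0] == 0x55555555u ? 0 : 1; /* 0xaa ^ 0xff == 0x55 */
    }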
@@ -168,13 +194,11 @@ void octet_string_set_to_zero(void *s, size_t len);
 #define be64_to_cpu(x) bswap_64((x))
 #else /* WORDS_BIGENDIAN */
 
-#if defined(__GNUC__) && (defined(HAVE_X86) || defined(__x86_64__))
+#if defined(__GNUC__)
 /* Fall back. */
 static inline uint32_t be32_to_cpu(uint32_t v)
 {
-    /* optimized for x86. */
-    asm("bswap %0" : "=r"(v) : "0"(v));
-    return v;
+    return __builtin_bswap32(v);
 }
 #else /* HAVE_X86 */
 #ifdef HAVE_NETINET_IN_H
@@ -187,7 +211,9 @@ static inline uint32_t be32_to_cpu(uint32_t v)
 
 static inline uint64_t be64_to_cpu(uint64_t v)
 {
-#ifdef NO_64BIT_MATH
+#if defined(__GNUC__)
+    v = __builtin_bswap64(v);
+#elif defined(NO_64BIT_MATH)
     /* use the make64 functions to do 64-bit math */
     v = make64(htonl(low32(v)), htonl(high32(v)));
 #else /* NO_64BIT_MATH */
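For reference, __builtin_bswap32/__builtin_bswap64 are available in both gcc and clang and typically compile to a single byte-swap instruction, replacing the inline asm that was limited to x86. A small self-check against the portable shift form (a standalone sketch, not part of the commit):

    #include <assert.h>
    #include <stdint.h>

    /* Portable reference byte swap, for comparison with the builtin. */
    static uint32_t bswap32_ref(uint32_t v)
    {
        return (v >> 24) | ((v >> 8) & 0xff00u) |
               ((v << 8) & 0xff0000u) | (v << 24);
    }

    int main(void)
    {
        assert(bswap32_ref(0x11223344u) == 0x44332211u);
        assert(__builtin_bswap32(0x11223344u) == bswap32_ref(0x11223344u));
        return 0;
    }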

crypto/math/datatypes.c

Lines changed: 207 additions & 3 deletions
@@ -53,6 +53,16 @@
 
 #include "datatypes.h"
 
+#if defined(__SSE2__)
+#include <tmmintrin.h>
+#endif
+
+#if defined(_MSC_VER)
+#define ALIGNMENT(N) __declspec(align(N))
+#else
+#define ALIGNMENT(N) __attribute__((aligned(N)))
+#endif
+
 /*
  * bit_string is a buffer that is used to hold output strings, e.g.
  * for printing.
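The ALIGNMENT(N) wrapper above papers over the MSVC/gcc syntax split for alignment attributes. A minimal usage sketch (illustrative, separate from the committed code):

    #include <stdint.h>

    #if defined(_MSC_VER)
    #define ALIGNMENT(N) __declspec(align(N))
    #else
    #define ALIGNMENT(N) __attribute__((aligned(N)))
    #endif

    ALIGNMENT(16)
    static const uint8_t table[16] = { 0 };

    int main(void)
    {
        /* The table's address is a 16-byte multiple under either compiler. */
        return ((uintptr_t)table % 16) == 0 ? 0 : 1;
    }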
@@ -123,6 +133,9 @@ char *v128_bit_string(v128_t *x)
 
 void v128_copy_octet_string(v128_t *x, const uint8_t s[16])
 {
+#if defined(__SSE2__)
+    _mm_storeu_si128((__m128i *)(x), _mm_loadu_si128((const __m128i *)(s)));
+#else
 #ifdef ALIGNMENT_32BIT_REQUIRED
     if ((((uint32_t)&s[0]) & 0x3) != 0)
 #endif
@@ -151,8 +164,67 @@ void v128_copy_octet_string(v128_t *x, const uint8_t s[16])
         v128_copy(x, v);
     }
 #endif
+#endif /* defined(__SSE2__) */
+}
+
+#if defined(__SSSE3__)
+
+/* clang-format off */
+
+ALIGNMENT(16)
+static const uint8_t right_shift_masks[5][16] = {
+    { 0u,   1u,   2u,   3u,   4u,   5u,   6u,   7u,
+      8u,   9u,   10u,  11u,  12u,  13u,  14u,  15u },
+    { 0x80, 0x80, 0x80, 0x80, 0u,   1u,   2u,   3u,
+      4u,   5u,   6u,   7u,   8u,   9u,   10u,  11u },
+    { 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
+      0u,   1u,   2u,   3u,   4u,   5u,   6u,   7u },
+    { 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
+      0x80, 0x80, 0x80, 0x80, 0u,   1u,   2u,   3u },
+    /* needed for bitvector_left_shift */
+    { 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
+      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 }
+};
+
+ALIGNMENT(16)
+static const uint8_t left_shift_masks[4][16] = {
+    { 0u,   1u,   2u,   3u,   4u,   5u,   6u,   7u,
+      8u,   9u,   10u,  11u,  12u,  13u,  14u,  15u },
+    { 4u,   5u,   6u,   7u,   8u,   9u,   10u,  11u,
+      12u,  13u,  14u,  15u,  0x80, 0x80, 0x80, 0x80 },
+    { 8u,   9u,   10u,  11u,  12u,  13u,  14u,  15u,
+      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 },
+    { 12u,  13u,  14u,  15u,  0x80, 0x80, 0x80, 0x80,
+      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 }
+};
+
+/* clang-format on */
+
+void v128_left_shift(v128_t *x, int shift)
+{
+    if (shift > 127) {
+        v128_set_to_zero(x);
+        return;
+    }
+
+    const int base_index = shift >> 5;
+    const int bit_index = shift & 31;
+
+    __m128i mm = _mm_loadu_si128((const __m128i *)x);
+    __m128i mm_shift_right = _mm_cvtsi32_si128(bit_index);
+    __m128i mm_shift_left = _mm_cvtsi32_si128(32 - bit_index);
+    mm = _mm_shuffle_epi8(mm, ((const __m128i *)left_shift_masks)[base_index]);
+
+    __m128i mm1 = _mm_srl_epi32(mm, mm_shift_right);
+    __m128i mm2 = _mm_sll_epi32(mm, mm_shift_left);
+    mm2 = _mm_srli_si128(mm2, 4);
+    mm1 = _mm_or_si128(mm1, mm2);
+
+    _mm_storeu_si128((__m128i *)x, mm1);
 }
 
+#else /* defined(__SSSE3__) */
+
 void v128_left_shift(v128_t *x, int shift)
 {
     int i;
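How the SSSE3 path works: left_shift_masks[base_index] uses _mm_shuffle_epi8 to discard the base_index low 32-bit words and move the rest down (mask bytes of 0x80 select zero), then the _mm_srl_epi32/_mm_sll_epi32 pair plus _mm_srli_si128(mm2, 4) stitches each word to the low bits of its higher neighbor. A scalar model of the same computation, as a standalone illustrative sketch (not the committed code):

    #include <stdint.h>
    #include <stdio.h>

    /* Scalar model of the SSSE3 path for one 128-bit value: drop the
     * `base` low words (the shuffle step), then stitch adjacent words
     * across the `bits` sub-word offset (the srl/sll/or steps). */
    static void v128_left_shift_model(uint32_t w[4], int shift)
    {
        const int base = shift >> 5, bits = shift & 31;
        uint32_t t[4] = { 0, 0, 0, 0 };
        for (int i = 0; i + base < 4; i++)
            t[i] = w[i + base];                       /* shuffle step */
        for (int i = 0; i < 4; i++) {
            uint32_t hi = (i + 1 < 4) ? t[i + 1] : 0; /* next word feeds in */
            w[i] = bits ? ((t[i] >> bits) | (hi << (32 - bits))) : t[i];
        }
    }

    int main(void)
    {
        uint32_t w[4] = { 0x44332211u, 0x88776655u, 0xccbbaa99u, 0x00ffeeddu };
        v128_left_shift_model(w, 40); /* base_index = 1, bit_index = 8 */
        printf("%08x %08x %08x %08x\n", w[0], w[1], w[2], w[3]);
        return 0;
    }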
@@ -179,6 +251,8 @@ void v128_left_shift(v128_t *x, int shift)
         x->v32[i] = 0;
 }
 
+#endif /* defined(__SSSE3__) */
+
 /* functions manipulating bitvector_t */
 
 int bitvector_alloc(bitvector_t *v, unsigned long length)
@@ -190,6 +264,7 @@ int bitvector_alloc(bitvector_t *v, unsigned long length)
         (length + bits_per_word - 1) & ~(unsigned long)((bits_per_word - 1));
 
     l = length / bits_per_word * bytes_per_word;
+    l = (l + 15ul) & ~15ul;
 
     /* allocate memory, then set parameters */
     if (l == 0) {
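The new rounding line exists so the word storage is always a whole number of 16-byte vectors, keeping the full-width _mm_loadu_si128/_mm_storeu_si128 accesses in the SSSE3 bitvector_left_shift below inside the allocation. A trivial illustration:

    #include <assert.h>

    int main(void)
    {
        /* A 96-bit vector needs 12 bytes of word storage; rounding up to
         * 16 yields one full __m128i so SIMD loads/stores stay in bounds. */
        unsigned long l = 12;
        l = (l + 15ul) & ~15ul;
        assert(l == 16);
        return 0;
    }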
@@ -225,6 +300,73 @@ void bitvector_set_to_zero(bitvector_t *x)
     memset(x->word, 0, x->length >> 3);
 }
 
+#if defined(__SSSE3__)
+
+void bitvector_left_shift(bitvector_t *x, int shift)
+{
+    if ((uint32_t)shift >= x->length) {
+        bitvector_set_to_zero(x);
+        return;
+    }
+
+    const int base_index = shift >> 5;
+    const int bit_index = shift & 31;
+    const int vec_length = (x->length + 127u) >> 7;
+    const __m128i *from = ((const __m128i *)x->word) + (base_index >> 2);
+    __m128i *to = (__m128i *)x->word;
+    __m128i *const end = to + vec_length;
+
+    __m128i mm_right_shift_mask =
+        ((const __m128i *)right_shift_masks)[4u - (base_index & 3u)];
+    __m128i mm_left_shift_mask =
+        ((const __m128i *)left_shift_masks)[base_index & 3u];
+    __m128i mm_shift_right = _mm_cvtsi32_si128(bit_index);
+    __m128i mm_shift_left = _mm_cvtsi32_si128(32 - bit_index);
+
+    __m128i mm_current = _mm_loadu_si128(from);
+    __m128i mm_current_r = _mm_srl_epi32(mm_current, mm_shift_right);
+    __m128i mm_current_l = _mm_sll_epi32(mm_current, mm_shift_left);
+
+    while ((end - from) >= 2) {
+        ++from;
+        __m128i mm_next = _mm_loadu_si128(from);
+
+        __m128i mm_next_r = _mm_srl_epi32(mm_next, mm_shift_right);
+        __m128i mm_next_l = _mm_sll_epi32(mm_next, mm_shift_left);
+        mm_current_l = _mm_alignr_epi8(mm_next_l, mm_current_l, 4);
+        mm_current = _mm_or_si128(mm_current_r, mm_current_l);
+
+        mm_current = _mm_shuffle_epi8(mm_current, mm_left_shift_mask);
+
+        __m128i mm_temp_next = _mm_srli_si128(mm_next_l, 4);
+        mm_temp_next = _mm_or_si128(mm_next_r, mm_temp_next);
+
+        mm_temp_next = _mm_shuffle_epi8(mm_temp_next, mm_right_shift_mask);
+        mm_current = _mm_or_si128(mm_temp_next, mm_current);
+
+        _mm_storeu_si128(to, mm_current);
+        ++to;
+
+        mm_current_r = mm_next_r;
+        mm_current_l = mm_next_l;
+    }
+
+    mm_current_l = _mm_srli_si128(mm_current_l, 4);
+    mm_current = _mm_or_si128(mm_current_r, mm_current_l);
+
+    mm_current = _mm_shuffle_epi8(mm_current, mm_left_shift_mask);
+
+    _mm_storeu_si128(to, mm_current);
+    ++to;
+
+    while (to < end) {
+        _mm_storeu_si128(to, _mm_setzero_si128());
+        ++to;
+    }
+}
+
+#else /* defined(__SSSE3__) */
+
 void bitvector_left_shift(bitvector_t *x, int shift)
 {
     int i;
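A note on the fifth right_shift_masks entry (all 0x80): mask bytes with the high bit set make _mm_shuffle_epi8 produce zero, so when base_index is a multiple of 4 the "carry-in from the next vector" term contributes nothing. A small standalone demonstration of that zeroing behavior (illustrative only; assumes SSSE3, e.g. -mssse3):

    #include <stdint.h>
    #include <stdio.h>
    #include <tmmintrin.h>

    int main(void)
    {
        /* 0x80 bytes in a pshufb mask select zero for that lane. */
        const uint8_t zero_mask[16] = { 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
                                        0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
                                        0x80, 0x80, 0x80, 0x80 };
        __m128i v = _mm_set1_epi8(0x5a);
        __m128i m = _mm_loadu_si128((const __m128i *)zero_mask);
        __m128i r = _mm_shuffle_epi8(v, m);
        uint8_t out[16];
        _mm_storeu_si128((__m128i *)out, r);
        printf("%d\n", out[0]); /* prints 0 */
        return 0;
    }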
@@ -253,16 +395,73 @@ void bitvector_left_shift(bitvector_t *x, int shift)
         x->word[i] = 0;
 }
 
+#endif /* defined(__SSSE3__) */
+
 int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len)
 {
-    uint8_t *end = b + len;
-    uint8_t accumulator = 0;
-
     /*
      * We use this somewhat obscure implementation to try to ensure the running
      * time only depends on len, even accounting for compiler optimizations.
      * The accumulator ends up zero iff the strings are equal.
      */
+    uint8_t *end = b + len;
+    uint32_t accumulator = 0;
+
+#if defined(__SSE2__)
+    __m128i mm_accumulator1 = _mm_setzero_si128();
+    __m128i mm_accumulator2 = _mm_setzero_si128();
+    for (int i = 0, n = len >> 5; i < n; ++i, a += 32, b += 32) {
+        __m128i mm_a1 = _mm_loadu_si128((const __m128i *)a);
+        __m128i mm_b1 = _mm_loadu_si128((const __m128i *)b);
+        __m128i mm_a2 = _mm_loadu_si128((const __m128i *)(a + 16));
+        __m128i mm_b2 = _mm_loadu_si128((const __m128i *)(b + 16));
+        mm_a1 = _mm_xor_si128(mm_a1, mm_b1);
+        mm_a2 = _mm_xor_si128(mm_a2, mm_b2);
+        mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1);
+        mm_accumulator2 = _mm_or_si128(mm_accumulator2, mm_a2);
+    }
+
+    mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_accumulator2);
+
+    if ((end - b) >= 16) {
+        __m128i mm_a1 = _mm_loadu_si128((const __m128i *)a);
+        __m128i mm_b1 = _mm_loadu_si128((const __m128i *)b);
+        mm_a1 = _mm_xor_si128(mm_a1, mm_b1);
+        mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1);
+        a += 16;
+        b += 16;
+    }
+
+    mm_accumulator1 = _mm_or_si128(
+        mm_accumulator1, _mm_unpackhi_epi64(mm_accumulator1, mm_accumulator1));
+    mm_accumulator1 =
+        _mm_or_si128(mm_accumulator1, _mm_srli_si128(mm_accumulator1, 4));
+    accumulator = _mm_cvtsi128_si32(mm_accumulator1);
+#else
+    uint32_t accumulator2 = 0;
+    for (int i = 0, n = len >> 3; i < n; ++i, a += 8, b += 8) {
+        uint32_t a_val1, b_val1;
+        uint32_t a_val2, b_val2;
+        memcpy(&a_val1, a, sizeof(a_val1));
+        memcpy(&b_val1, b, sizeof(b_val1));
+        memcpy(&a_val2, a + 4, sizeof(a_val2));
+        memcpy(&b_val2, b + 4, sizeof(b_val2));
+        accumulator |= a_val1 ^ b_val1;
+        accumulator2 |= a_val2 ^ b_val2;
+    }
+
+    accumulator |= accumulator2;
+
+    if ((end - b) >= 4) {
+        uint32_t a_val, b_val;
+        memcpy(&a_val, a, sizeof(a_val));
+        memcpy(&b_val, b, sizeof(b_val));
+        accumulator |= a_val ^ b_val;
+        a += 4;
+        b += 4;
+    }
+#endif
+
     while (b < end)
         accumulator |= (*a++ ^ *b++);
 
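The tail of the SSE2 path above reduces the 128-bit accumulator horizontally: _mm_unpackhi_epi64 ORs the high 64 bits into the low 64, the 4-byte _mm_srli_si128 then ORs the remaining two 32-bit lanes together, and _mm_cvtsi128_si32 extracts the combined word. A scalar model of that reduction, as a standalone illustrative sketch:

    #include <assert.h>
    #include <stdint.h>

    /* Scalar model of the diff's horizontal OR: fold the four 32-bit
     * lanes of a 128-bit accumulator into one word. */
    static uint32_t hor_or_model(const uint32_t lanes[4])
    {
        uint32_t lo = lanes[0] | lanes[2]; /* unpackhi_epi64 + or */
        uint32_t hi = lanes[1] | lanes[3];
        return lo | hi;                    /* srli_si128 by 4 + or, extract */
    }

    int main(void)
    {
        uint32_t lanes[4] = { 0, 0x10u, 0, 0 };
        assert(hor_or_model(lanes) != 0); /* any nonzero lane survives */
        return 0;
    }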
@@ -272,9 +471,14 @@ int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len)
 
 void srtp_cleanse(void *s, size_t len)
 {
+#if defined(__GNUC__)
+    memset(s, 0, len);
+    __asm__ __volatile__("" : : "r"(s) : "memory");
+#else
     volatile unsigned char *p = (volatile unsigned char *)s;
     while (len--)
         *p++ = 0;
+#endif
 }
 
 void octet_string_set_to_zero(void *s, size_t len)
