@@ -268,43 +268,49 @@ unsigned int rej_uniform_avx(uint32_t *r,
                              const uint8_t *buf,
                              unsigned int buflen)
 {
-  unsigned int i, ctr, pos;
-  uint32_t vec[8] __attribute__((aligned(32)));
-  __m256i d, tmp;
+  unsigned int ctr, pos;
   uint32_t good;
+  __m256i d, tmp;
   const __m256i bound = _mm256_set1_epi32(Q);
+  const __m256i mask = _mm256_set1_epi32(0x7FFFFF);
+  const __m256i idx8 = _mm256_set_epi8(-1,15,14,13,-1,12,11,10,
+                                       -1, 9, 8, 7,-1, 6, 5, 4,
+                                       -1,11,10, 9,-1, 8, 7, 6,
+                                       -1, 5, 4, 3,-1, 2, 1, 0);

   if(len < 8 || buflen < 24)
     return 0;

   ctr = pos = 0;
   while(ctr <= len - 8 && pos <= buflen - 24) {
-    for(i = 0; i < 8; i++) {
-      vec[i] = buf[pos++];
-      vec[i] |= (uint32_t)buf[pos++] << 8;
-      vec[i] |= (uint32_t)buf[pos++] << 16;
-      vec[i] &= 0x7FFFFF;
-    }
+    d = _mm256_loadu_si256((__m256i *)&buf[pos]);
+    d = _mm256_permute4x64_epi64(d, 0x94);
+    d = _mm256_shuffle_epi8(d, idx8);
+    d = _mm256_and_si256(d, mask);
+    pos += 24;

-    d = _mm256_load_si256((__m256i *)vec);
     tmp = _mm256_cmpgt_epi32(bound, d);
     good = _mm256_movemask_ps((__m256)tmp);
     __m128i rid = _mm_loadl_epi64((__m128i *)&idx[good]);
     tmp = _mm256_cvtepu8_epi32(rid);
     d = _mm256_permutevar8x32_epi32(d, tmp);
+
     _mm256_storeu_si256((__m256i *)&r[ctr], d);
     ctr += __builtin_popcount(good);
   }

+#ifndef DILITHIUM_USE_AES
+  uint32_t t;
   while(ctr < len && pos <= buflen - 3) {
-    vec[0] = buf[pos++];
-    vec[0] |= (uint32_t)buf[pos++] << 8;
-    vec[0] |= (uint32_t)buf[pos++] << 16;
-    vec[0] &= 0x7FFFFF;
+    t = buf[pos++];
+    t |= (uint32_t)buf[pos++] << 8;
+    t |= (uint32_t)buf[pos++] << 16;
+    t &= 0x7FFFFF;

-    if(vec[0] < Q)
-      r[ctr++] = vec[0];
+    if(t < Q)
+      r[ctr++] = t;
   }
+#endif

   return ctr;
 }
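
The removed byte loop above is still the clearest specification of what the new loadu/permute4x64/shuffle_epi8/and sequence has to produce: eight candidates, each built from three consecutive little-endian bytes of buf and masked to 23 bits. A stand-alone scalar sketch of that unpacking, not part of the patch and useful only for cross-checking the vector path on sample input:

#include <stdint.h>
#include <stdio.h>

/* Scalar reference: 24 packed bytes -> eight 23-bit candidates,
 * mirroring the loop removed in the hunk above. */
static void unpack24_scalar(uint32_t out[8], const uint8_t in[24]) {
  for (unsigned i = 0; i < 8; i++)
    out[i] = ((uint32_t)in[3*i] | ((uint32_t)in[3*i + 1] << 8) |
              ((uint32_t)in[3*i + 2] << 16)) & 0x7FFFFF;
}

int main(void) {
  uint8_t buf[24];
  uint32_t v[8];
  for (unsigned i = 0; i < 24; i++)
    buf[i] = (uint8_t)(31*i + 7);   /* arbitrary deterministic input */
  unpack24_scalar(v, buf);
  for (unsigned i = 0; i < 8; i++)
    printf("v[%u] = %u\n", i, (unsigned)v[i]);
  return 0;
}

Feeding the same 24 bytes through the intrinsic sequence and comparing against this output is a quick way to validate the idx8 byte indices.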
@@ -314,76 +320,68 @@ unsigned int rej_eta_avx(uint32_t *r,
                          const uint8_t *buf,
                          unsigned int buflen)
 {
-  unsigned int i, ctr, pos;
-  uint8_t vec[32] __attribute__((aligned(32)));
-  __m256i tmp0, tmp1;
-  __m128i d0, d1, rid;
+  unsigned int ctr, pos;
   uint32_t good;
-  const __m256i bound = _mm256_set1_epi8(2*ETA + 1);
+  __m128i f0, f1;
+  __m256i v;
   const __m256i off = _mm256_set1_epi32(Q + ETA);
+  const __m128i bound = _mm_set1_epi8(2*ETA + 1);
+#if ETA <= 3
+  const __m128i mask = _mm_set1_epi8(7);
+#else
+  const __m128i mask = _mm_set1_epi8(15);
+#endif
+
+  if(len < 16 || buflen < 8)
+    return 0;

   ctr = pos = 0;
-  while(ctr + 32 <= len && pos + 16 <= buflen) {
-    for(i = 0; i < 16; i++) {
+  while(ctr <= len - 16 && pos <= buflen - 8) {
+    f0 = _mm_loadl_epi64((__m128i *)&buf[pos]);
+    f0 = _mm_cvtepu8_epi16(f0);
 #if ETA <= 3
-      vec[2*i+0] = buf[pos] & 0x07;
-      vec[2*i+1] = buf[pos++] >> 5;
+    f1 = _mm_slli_epi16(f0, 3);
 #else
-      vec[2*i+0] = buf[pos] & 0x0F;
-      vec[2*i+1] = buf[pos++] >> 4;
+    f1 = _mm_slli_epi16(f0, 4);
 #endif
-    }
+    f0 = _mm_or_si128(f0, f1);
+    f0 = _mm_and_si128(f0, mask);
+    pos += 8;

-    tmp0 = _mm256_load_si256((__m256i *)vec);
-    tmp1 = _mm256_cmpgt_epi8(bound, tmp0);
-    good = _mm256_movemask_epi8(tmp1);
+    f1 = _mm_cmpgt_epi8(bound, f0);
+    good = _mm_movemask_epi8(f1);

-    d0 = _mm256_castsi256_si128(tmp0);
-    rid = _mm_loadl_epi64((__m128i *)&idx[good & 0xFF]);
-    d1 = _mm_shuffle_epi8(d0, rid);
-    tmp1 = _mm256_cvtepu8_epi32(d1);
-    tmp1 = _mm256_sub_epi32(off, tmp1);
-    _mm256_storeu_si256((__m256i *)&r[ctr], tmp1);
+    f1 = _mm_loadl_epi64((__m128i *)&idx[good & 0xFF]);
+    f1 = _mm_shuffle_epi8(f0, f1);
+    v = _mm256_cvtepu8_epi32(f1);
+    v = _mm256_sub_epi32(off, v);
+    _mm256_storeu_si256((__m256i *)&r[ctr], v);
     ctr += __builtin_popcount(good & 0xFF);

-    d0 = _mm_bsrli_si128(d0, 8);
-    rid = _mm_loadl_epi64((__m128i *)&idx[(good >> 8) & 0xFF]);
-    d1 = _mm_shuffle_epi8(d0, rid);
-    tmp1 = _mm256_cvtepu8_epi32(d1);
-    tmp1 = _mm256_sub_epi32(off, tmp1);
-    _mm256_storeu_si256((__m256i *)&r[ctr], tmp1);
-    ctr += __builtin_popcount((good >> 8) & 0xFF);
-
-    d0 = _mm256_extracti128_si256(tmp0, 1);
-    rid = _mm_loadl_epi64((__m128i *)&idx[(good >> 16) & 0xFF]);
-    d1 = _mm_shuffle_epi8(d0, rid);
-    tmp1 = _mm256_cvtepu8_epi32(d1);
-    tmp1 = _mm256_sub_epi32(off, tmp1);
-    _mm256_storeu_si256((__m256i *)&r[ctr], tmp1);
-    ctr += __builtin_popcount((good >> 16) & 0xFF);
-
-    d0 = _mm_bsrli_si128(d0, 8);
-    rid = _mm_loadl_epi64((__m128i *)&idx[(good >> 24) & 0xFF]);
-    d1 = _mm_shuffle_epi8(d0, rid);
-    tmp1 = _mm256_cvtepu8_epi32(d1);
-    tmp1 = _mm256_sub_epi32(off, tmp1);
-    _mm256_storeu_si256((__m256i *)&r[ctr], tmp1);
-    ctr += __builtin_popcount((good >> 24) & 0xFF);
+    f0 = _mm_bsrli_si128(f0, 8);
+    good >>= 8;
+    f1 = _mm_loadl_epi64((__m128i *)&idx[good]);
+    f0 = _mm_shuffle_epi8(f0, f1);
+    v = _mm256_cvtepu8_epi32(f0);
+    v = _mm256_sub_epi32(off, v);
+    _mm256_storeu_si256((__m256i *)&r[ctr], v);
+    ctr += __builtin_popcount(good);
   }

+  uint32_t t0, t1;
   while(ctr < len && pos < buflen) {
 #if ETA <= 3
-    vec[0] = buf[pos] & 0x07;
-    vec[1] = buf[pos++] >> 5;
+    t0 = buf[pos] & 0x07;
+    t1 = buf[pos++] >> 5;
 #else
-    vec[0] = buf[pos] & 0x0F;
-    vec[1] = buf[pos++] >> 4;
+    t0 = buf[pos] & 0x0F;
+    t1 = buf[pos++] >> 4;
 #endif

-    if(vec[0] <= 2*ETA)
-      r[ctr++] = Q + ETA - vec[0];
-    if(vec[1] <= 2*ETA && ctr < len)
-      r[ctr++] = Q + ETA - vec[1];
+    if(t0 <= 2*ETA)
+      r[ctr++] = Q + ETA - t0;
+    if(t1 <= 2*ETA && ctr < len)
+      r[ctr++] = Q + ETA - t1;
   }

   return ctr;
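
The shift/or/mask replacement for the nibble loop is easiest to see on one byte: after _mm_cvtepu8_epi16 each 16-bit lane holds a single input byte, OR-ing the lane with a copy shifted left by 4 (by 3 in the ETA <= 3 build) moves the upper bit group into the lane's high byte, and the per-byte mask then isolates both halves. A scalar sketch of the ETA > 3 case, not taken from the patch, using an arbitrary test byte:

#include <stdint.h>
#include <stdio.h>

int main(void) {
  uint8_t b = 0xA7;      /* arbitrary sample byte */
  uint16_t lane = b;     /* _mm_cvtepu8_epi16 widens each byte to a 16-bit lane */
  uint16_t split = (uint16_t)((lane | (lane << 4)) & 0x0F0F);
  printf("low  nibble = 0x%X\n", (unsigned)(split & 0xFF));        /* prints 0x7 */
  printf("high nibble = 0x%X\n", (unsigned)((split >> 8) & 0xFF)); /* prints 0xA */
  return 0;
}

The ETA <= 3 build works the same way with a shift of 3 and mask 7, matching the scalar tail's & 0x07 / >> 5 pair.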
@@ -394,57 +392,61 @@ unsigned int rej_gamma1m1_avx(uint32_t *r,
                              const uint8_t *buf,
                              unsigned int buflen)
 {
-  unsigned int i, ctr, pos;
-  uint32_t vec[8] __attribute__((aligned(32)));
-  __m256i d, tmp;
+  unsigned int ctr, pos;
   uint32_t good;
+  __m256i d, tmp;
   const __m256i bound = _mm256_set1_epi32(2*GAMMA1 - 1);
-  const __m256i off = _mm256_set1_epi32(Q + GAMMA1 - 1);
-
-  ctr = pos = 0;
-  while(ctr + 8 <= len && pos + 20 <= buflen) {
-    for(i = 0; i < 4; i++) {
-      vec[2*i+0] = buf[pos + 0];
-      vec[2*i+0] |= (uint32_t)buf[pos + 1] << 8;
-      vec[2*i+0] |= (uint32_t)buf[pos + 2] << 16;
-      vec[2*i+0] &= 0xFFFFF;
+  const __m256i off = _mm256_set1_epi32(Q + GAMMA1 - 1);
+  const __m256i mask = _mm256_set1_epi32(0xFFFFF);
+  const __m256i srlv = _mm256_set1_epi64x(4ULL << 32);
+  const __m256i idx8 = _mm256_set_epi8(-1,11,10, 9,-1, 9, 8, 7,
+                                       -1, 6, 5, 4,-1, 4, 3, 2,
+                                       -1, 9, 8, 7,-1, 7, 6, 5,
+                                       -1, 4, 3, 2,-1, 2, 1, 0);

-      vec[2*i+1] = buf[pos + 2] >> 4;
-      vec[2*i+1] |= (uint32_t)buf[pos + 3] << 4;
-      vec[2*i+1] |= (uint32_t)buf[pos + 4] << 12;
+  if(len < 8 || buflen < 20)
+    return 0;

-      pos += 5;
-    }
+  ctr = pos = 0;
+  while(ctr <= len - 8 && pos <= buflen - 20) {
+    d = _mm256_loadu_si256((__m256i *)&buf[pos]);
+    d = _mm256_permute4x64_epi64(d, 0x94);
+    d = _mm256_shuffle_epi8(d, idx8);
+    d = _mm256_srlv_epi32(d, srlv);
+    d = _mm256_and_si256(d, mask);
+    pos += 20;

-    d = _mm256_loadu_si256((__m256i *)vec);
     tmp = _mm256_cmpgt_epi32(bound, d);
     good = _mm256_movemask_ps((__m256)tmp);
     d = _mm256_sub_epi32(off, d);
-
     __m128i rid = _mm_loadl_epi64((__m128i *)&idx[good]);
     tmp = _mm256_cvtepu8_epi32(rid);
     d = _mm256_permutevar8x32_epi32(d, tmp);
+
     _mm256_storeu_si256((__m256i *)&r[ctr], d);
     ctr += __builtin_popcount(good);
   }

-  while(ctr < len && pos + 5 <= buflen) {
-    vec[0] = buf[pos + 0];
-    vec[0] |= (uint32_t)buf[pos + 1] << 8;
-    vec[0] |= (uint32_t)buf[pos + 2] << 16;
-    vec[0] &= 0xFFFFF;
+#ifndef DILITHIUM_USE_AES
+  uint32_t t0, t1;
+  while(ctr < len && pos <= buflen - 5) {
+    t0 = buf[pos];
+    t0 |= (uint32_t)buf[pos + 1] << 8;
+    t0 |= (uint32_t)buf[pos + 2] << 16;
+    t0 &= 0xFFFFF;

-    vec[1] = buf[pos + 2] >> 4;
-    vec[1] |= (uint32_t)buf[pos + 3] << 4;
-    vec[1] |= (uint32_t)buf[pos + 4] << 12;
+    t1 = buf[pos + 2] >> 4;
+    t1 |= (uint32_t)buf[pos + 3] << 4;
+    t1 |= (uint32_t)buf[pos + 4] << 12;

     pos += 5;

-    if(vec[0] <= 2*GAMMA1 - 2)
-      r[ctr++] = Q + GAMMA1 - 1 - vec[0];
-    if(vec[1] <= 2*GAMMA1 - 2 && ctr < len)
-      r[ctr++] = Q + GAMMA1 - 1 - vec[1];
+    if(t0 <= 2*GAMMA1 - 2)
+      r[ctr++] = Q + GAMMA1 - 1 - t0;
+    if(t1 <= 2*GAMMA1 - 2 && ctr < len)
+      r[ctr++] = Q + GAMMA1 - 1 - t1;
   }
+#endif

   return ctr;
 }
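
As in rej_uniform_avx, the deleted scalar loop documents the packing the vector path must reproduce: every 5-byte group carries two 20-bit values, the second starting at the high nibble of the shared middle byte, which is why the srlv constant alternates per-lane shift counts of 0 and 4. A stand-alone scalar sketch of that unpacking, not part of the patch:

#include <stdint.h>
#include <stdio.h>

/* Scalar reference: 20 packed bytes -> eight 20-bit candidates,
 * two per 5-byte group, mirroring the loop removed in the hunk above. */
static void unpack20_scalar(uint32_t out[8], const uint8_t in[20]) {
  for (unsigned i = 0; i < 4; i++) {
    const uint8_t *p = in + 5*i;
    out[2*i]     = ((uint32_t)p[0] | ((uint32_t)p[1] << 8) |
                    ((uint32_t)p[2] << 16)) & 0xFFFFF;
    out[2*i + 1] = (uint32_t)(p[2] >> 4) | ((uint32_t)p[3] << 4) |
                   ((uint32_t)p[4] << 12);
  }
}

int main(void) {
  uint8_t buf[20];
  uint32_t v[8];
  for (unsigned i = 0; i < 20; i++)
    buf[i] = (uint8_t)(17*i + 3);   /* arbitrary deterministic input */
  unpack20_scalar(v, buf);
  for (unsigned i = 0; i < 8; i++)
    printf("v[%u] = %u\n", i, (unsigned)v[i]);
  return 0;
}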