Skip to content

Commit 69d91bc

Browse files
committed
better rejection sampling
1 parent 4b83264 commit 69d91bc

File tree

2 files changed

+106
-101
lines changed

2 files changed

+106
-101
lines changed

avx2/poly.c

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,7 @@ static unsigned int rej_uniform(uint32_t *a,
390390
void poly_uniform_preinit(poly *a, stream128_state *state)
391391
{
392392
unsigned int ctr;
393-
uint8_t buf[POLY_UNIFORM_NBLOCKS*STREAM128_BLOCKBYTES]
394-
__attribute__((aligned(32)));
393+
uint8_t buf[POLY_UNIFORM_NBLOCKS*STREAM128_BLOCKBYTES];
395394

396395
stream128_squeezeblocks(buf, POLY_UNIFORM_NBLOCKS, state);
397396
ctr = rej_uniform_avx(a->coeffs, N, buf, sizeof(buf));
@@ -1046,8 +1045,12 @@ void polyz_unpack(poly * restrict r, const uint8_t * restrict a) {
10461045
**************************************************/
10471046
void polyw1_pack(uint8_t * restrict r, const poly * restrict a) {
10481047
unsigned int i;
1048+
// _mm256i vec;
10491049
DBENCH_START();
10501050

1051+
// for(i = 0; i < N/8: ++i) {
1052+
// vec = _mm256_load_si256((__m256i *)&a->coeffs[8*i]);
1053+
10511054
for(i = 0; i < N/2; ++i)
10521055
r[i] = a->coeffs[2*i+0] | (a->coeffs[2*i+1] << 4);
10531056

avx2/rejsample.c

Lines changed: 101 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -268,43 +268,49 @@ unsigned int rej_uniform_avx(uint32_t *r,
268268
const uint8_t *buf,
269269
unsigned int buflen)
270270
{
271-
unsigned int i, ctr, pos;
272-
uint32_t vec[8] __attribute__((aligned(32)));
273-
__m256i d, tmp;
271+
unsigned int ctr, pos;
274272
uint32_t good;
273+
__m256i d, tmp;
275274
const __m256i bound = _mm256_set1_epi32(Q);
275+
const __m256i mask = _mm256_set1_epi32(0x7FFFFF);
276+
const __m256i idx8 = _mm256_set_epi8(-1,15,14,13,-1,12,11,10,
277+
-1, 9, 8, 7,-1, 6, 5, 4,
278+
-1,11,10, 9,-1, 8, 7, 6,
279+
-1, 5, 4, 3,-1, 2, 1, 0);
276280

277281
if(len < 8 || buflen < 24)
278282
return 0;
279283

280284
ctr = pos = 0;
281285
while(ctr <= len - 8 && pos <= buflen - 24) {
282-
for(i = 0; i < 8; i++) {
283-
vec[i] = buf[pos++];
284-
vec[i] |= (uint32_t)buf[pos++] << 8;
285-
vec[i] |= (uint32_t)buf[pos++] << 16;
286-
vec[i] &= 0x7FFFFF;
287-
}
286+
d = _mm256_loadu_si256((__m256i *)&buf[pos]);
287+
d = _mm256_permute4x64_epi64(d, 0x94);
288+
d = _mm256_shuffle_epi8(d, idx8);
289+
d = _mm256_and_si256(d, mask);
290+
pos += 24;
288291

289-
d = _mm256_load_si256((__m256i *)vec);
290292
tmp = _mm256_cmpgt_epi32(bound, d);
291293
good = _mm256_movemask_ps((__m256)tmp);
292294
__m128i rid = _mm_loadl_epi64((__m128i *)&idx[good]);
293295
tmp = _mm256_cvtepu8_epi32(rid);
294296
d = _mm256_permutevar8x32_epi32(d, tmp);
297+
295298
_mm256_storeu_si256((__m256i *)&r[ctr], d);
296299
ctr += __builtin_popcount(good);
297300
}
298301

302+
#ifndef DILITHIUM_USE_AES
303+
uint32_t t;
299304
while(ctr < len && pos <= buflen - 3) {
300-
vec[0] = buf[pos++];
301-
vec[0] |= (uint32_t)buf[pos++] << 8;
302-
vec[0] |= (uint32_t)buf[pos++] << 16;
303-
vec[0] &= 0x7FFFFF;
305+
t = buf[pos++];
306+
t |= (uint32_t)buf[pos++] << 8;
307+
t |= (uint32_t)buf[pos++] << 16;
308+
t &= 0x7FFFFF;
304309

305-
if(vec[0] < Q)
306-
r[ctr++] = vec[0];
310+
if(t < Q)
311+
r[ctr++] = t;
307312
}
313+
#endif
308314

309315
return ctr;
310316
}
@@ -314,76 +320,68 @@ unsigned int rej_eta_avx(uint32_t *r,
314320
const uint8_t *buf,
315321
unsigned int buflen)
316322
{
317-
unsigned int i, ctr, pos;
318-
uint8_t vec[32] __attribute__((aligned(32)));
319-
__m256i tmp0, tmp1;
320-
__m128i d0, d1, rid;
323+
unsigned int ctr, pos;
321324
uint32_t good;
322-
const __m256i bound = _mm256_set1_epi8(2*ETA + 1);
325+
__m128i f0, f1;
326+
__m256i v;
323327
const __m256i off = _mm256_set1_epi32(Q + ETA);
328+
const __m128i bound = _mm_set1_epi8(2*ETA + 1);
329+
#if ETA <= 3
330+
const __m128i mask = _mm_set1_epi8(7);
331+
#else
332+
const __m128i mask = _mm_set1_epi8(15);
333+
#endif
334+
335+
if(len < 16 || buflen < 8)
336+
return 0;
324337

325338
ctr = pos = 0;
326-
while(ctr + 32 <= len && pos + 16 <= buflen) {
327-
for(i = 0; i < 16; i++) {
339+
while(ctr <= len - 16 && pos <= buflen - 8) {
340+
f0 = _mm_loadl_epi64((__m128i *)&buf[pos]);
341+
f0 = _mm_cvtepu8_epi16(f0);
328342
#if ETA <= 3
329-
vec[2*i+0] = buf[pos] & 0x07;
330-
vec[2*i+1] = buf[pos++] >> 5;
343+
f1 = _mm_slli_epi16(f0, 3);
331344
#else
332-
vec[2*i+0] = buf[pos] & 0x0F;
333-
vec[2*i+1] = buf[pos++] >> 4;
345+
f1 = _mm_slli_epi16(f0, 4);
334346
#endif
335-
}
347+
f0 = _mm_or_si128(f0, f1);
348+
f0 = _mm_and_si128(f0, mask);
349+
pos += 8;
336350

337-
tmp0 = _mm256_load_si256((__m256i *)vec);
338-
tmp1 = _mm256_cmpgt_epi8(bound, tmp0);
339-
good = _mm256_movemask_epi8(tmp1);
351+
f1 = _mm_cmpgt_epi8(bound, f0);
352+
good = _mm_movemask_epi8(f1);
340353

341-
d0 = _mm256_castsi256_si128(tmp0);
342-
rid = _mm_loadl_epi64((__m128i *)&idx[good & 0xFF]);
343-
d1 = _mm_shuffle_epi8(d0, rid);
344-
tmp1 = _mm256_cvtepu8_epi32(d1);
345-
tmp1 = _mm256_sub_epi32(off, tmp1);
346-
_mm256_storeu_si256((__m256i *)&r[ctr], tmp1);
354+
f1 = _mm_loadl_epi64((__m128i *)&idx[good & 0xFF]);
355+
f1 = _mm_shuffle_epi8(f0, f1);
356+
v = _mm256_cvtepu8_epi32(f1);
357+
v = _mm256_sub_epi32(off, v);
358+
_mm256_storeu_si256((__m256i *)&r[ctr], v);
347359
ctr += __builtin_popcount(good & 0xFF);
348360

349-
d0 = _mm_bsrli_si128(d0, 8);
350-
rid = _mm_loadl_epi64((__m128i *)&idx[(good >> 8) & 0xFF]);
351-
d1 = _mm_shuffle_epi8(d0, rid);
352-
tmp1 = _mm256_cvtepu8_epi32(d1);
353-
tmp1 = _mm256_sub_epi32(off, tmp1);
354-
_mm256_storeu_si256((__m256i *)&r[ctr], tmp1);
355-
ctr += __builtin_popcount((good >> 8) & 0xFF);
356-
357-
d0 = _mm256_extracti128_si256(tmp0, 1);
358-
rid = _mm_loadl_epi64((__m128i *)&idx[(good >> 16) & 0xFF]);
359-
d1 = _mm_shuffle_epi8(d0, rid);
360-
tmp1 = _mm256_cvtepu8_epi32(d1);
361-
tmp1 = _mm256_sub_epi32(off, tmp1);
362-
_mm256_storeu_si256((__m256i *)&r[ctr], tmp1);
363-
ctr += __builtin_popcount((good >> 16) & 0xFF);
364-
365-
d0 = _mm_bsrli_si128(d0, 8);
366-
rid = _mm_loadl_epi64((__m128i *)&idx[(good >> 24) & 0xFF]);
367-
d1 = _mm_shuffle_epi8(d0, rid);
368-
tmp1 = _mm256_cvtepu8_epi32(d1);
369-
tmp1 = _mm256_sub_epi32(off, tmp1);
370-
_mm256_storeu_si256((__m256i *)&r[ctr], tmp1);
371-
ctr += __builtin_popcount((good >> 24) & 0xFF);
361+
f0 = _mm_bsrli_si128(f0, 8);
362+
good >>= 8;
363+
f1 = _mm_loadl_epi64((__m128i *)&idx[good]);
364+
f0 = _mm_shuffle_epi8(f0, f1);
365+
v = _mm256_cvtepu8_epi32(f0);
366+
v = _mm256_sub_epi32(off, v);
367+
_mm256_storeu_si256((__m256i *)&r[ctr], v);
368+
ctr += __builtin_popcount(good);
372369
}
373370

371+
uint32_t t0, t1;
374372
while(ctr < len && pos < buflen) {
375373
#if ETA <= 3
376-
vec[0] = buf[pos] & 0x07;
377-
vec[1] = buf[pos++] >> 5;
374+
t0 = buf[pos] & 0x07;
375+
t1 = buf[pos++] >> 5;
378376
#else
379-
vec[0] = buf[pos] & 0x0F;
380-
vec[1] = buf[pos++] >> 4;
377+
t0 = buf[pos] & 0x0F;
378+
t1 = buf[pos++] >> 4;
381379
#endif
382380

383-
if(vec[0] <= 2*ETA)
384-
r[ctr++] = Q + ETA - vec[0];
385-
if(vec[1] <= 2*ETA && ctr < len)
386-
r[ctr++] = Q + ETA - vec[1];
381+
if(t0 <= 2*ETA)
382+
r[ctr++] = Q + ETA - t0;
383+
if(t1 <= 2*ETA && ctr < len)
384+
r[ctr++] = Q + ETA - t1;
387385
}
388386

389387
return ctr;
@@ -394,57 +392,61 @@ unsigned int rej_gamma1m1_avx(uint32_t *r,
394392
const uint8_t *buf,
395393
unsigned int buflen)
396394
{
397-
unsigned int i, ctr, pos;
398-
uint32_t vec[8] __attribute__((aligned(32)));
399-
__m256i d, tmp;
395+
unsigned int ctr, pos;
400396
uint32_t good;
397+
__m256i d, tmp;
401398
const __m256i bound = _mm256_set1_epi32(2*GAMMA1 - 1);
402-
const __m256i off = _mm256_set1_epi32(Q + GAMMA1 - 1);
403-
404-
ctr = pos = 0;
405-
while(ctr + 8 <= len && pos + 20 <= buflen) {
406-
for(i = 0; i < 4; i++) {
407-
vec[2*i+0] = buf[pos + 0];
408-
vec[2*i+0] |= (uint32_t)buf[pos + 1] << 8;
409-
vec[2*i+0] |= (uint32_t)buf[pos + 2] << 16;
410-
vec[2*i+0] &= 0xFFFFF;
399+
const __m256i off = _mm256_set1_epi32(Q + GAMMA1 - 1);
400+
const __m256i mask = _mm256_set1_epi32(0xFFFFF);
401+
const __m256i srlv = _mm256_set1_epi64x(4ULL << 32);
402+
const __m256i idx8 = _mm256_set_epi8(-1,11,10, 9,-1, 9, 8, 7,
403+
-1, 6, 5, 4,-1, 4, 3, 2,
404+
-1, 9, 8, 7,-1, 7, 6, 5,
405+
-1, 4, 3, 2,-1, 2, 1, 0);
411406

412-
vec[2*i+1] = buf[pos + 2] >> 4;
413-
vec[2*i+1] |= (uint32_t)buf[pos + 3] << 4;
414-
vec[2*i+1] |= (uint32_t)buf[pos + 4] << 12;
407+
if(len < 8 || buflen < 20)
408+
return 0;
415409

416-
pos += 5;
417-
}
410+
ctr = pos = 0;
411+
while(ctr <= len - 8 && pos <= buflen - 20) {
412+
d = _mm256_loadu_si256((__m256i *)&buf[pos]);
413+
d = _mm256_permute4x64_epi64(d, 0x94);
414+
d = _mm256_shuffle_epi8(d, idx8);
415+
d = _mm256_srlv_epi32(d, srlv);
416+
d = _mm256_and_si256(d, mask);
417+
pos += 20;
418418

419-
d = _mm256_loadu_si256((__m256i *)vec);
420419
tmp = _mm256_cmpgt_epi32(bound, d);
421420
good = _mm256_movemask_ps((__m256)tmp);
422421
d = _mm256_sub_epi32(off, d);
423-
424422
__m128i rid = _mm_loadl_epi64((__m128i *)&idx[good]);
425423
tmp = _mm256_cvtepu8_epi32(rid);
426424
d = _mm256_permutevar8x32_epi32(d, tmp);
425+
427426
_mm256_storeu_si256((__m256i *)&r[ctr], d);
428427
ctr += __builtin_popcount(good);
429428
}
430429

431-
while(ctr < len && pos + 5 <= buflen) {
432-
vec[0] = buf[pos + 0];
433-
vec[0] |= (uint32_t)buf[pos + 1] << 8;
434-
vec[0] |= (uint32_t)buf[pos + 2] << 16;
435-
vec[0] &= 0xFFFFF;
430+
#ifndef DILITHIUM_USE_AES
431+
uint32_t t0, t1;
432+
while(ctr < len && pos <= buflen - 5) {
433+
t0 = buf[pos];
434+
t0 |= (uint32_t)buf[pos + 1] << 8;
435+
t0 |= (uint32_t)buf[pos + 2] << 16;
436+
t0 &= 0xFFFFF;
436437

437-
vec[1] = buf[pos + 2] >> 4;
438-
vec[1] |= (uint32_t)buf[pos + 3] << 4;
439-
vec[1] |= (uint32_t)buf[pos + 4] << 12;
438+
t1 = buf[pos + 2] >> 4;
439+
t1 |= (uint32_t)buf[pos + 3] << 4;
440+
t1 |= (uint32_t)buf[pos + 4] << 12;
440441

441442
pos += 5;
442443

443-
if(vec[0] <= 2*GAMMA1 - 2)
444-
r[ctr++] = Q + GAMMA1 - 1 - vec[0];
445-
if(vec[1] <= 2*GAMMA1 - 2 && ctr < len)
446-
r[ctr++] = Q + GAMMA1 - 1 - vec[1];
444+
if(t0 <= 2*GAMMA1 - 2)
445+
r[ctr++] = Q + GAMMA1 - 1 - t0;
446+
if(t1 <= 2*GAMMA1 - 2 && ctr < len)
447+
r[ctr++] = Q + GAMMA1 - 1 - t1;
447448
}
449+
#endif
448450

449451
return ctr;
450452
}

0 commit comments

Comments (0)