Skip to content
4 changes: 4 additions & 0 deletions BIBLIOGRAPHY.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ source code and documentation.
- [mldsa/native/x86_64/src/intt.S](mldsa/native/x86_64/src/intt.S)
- [mldsa/native/x86_64/src/ntt.S](mldsa/native/x86_64/src/ntt.S)
- [mldsa/native/x86_64/src/nttunpack.S](mldsa/native/x86_64/src/nttunpack.S)
- [mldsa/native/x86_64/src/pointwise.S](mldsa/native/x86_64/src/pointwise.S)
- [mldsa/native/x86_64/src/pointwise_acc_l4.S](mldsa/native/x86_64/src/pointwise_acc_l4.S)
- [mldsa/native/x86_64/src/pointwise_acc_l5.S](mldsa/native/x86_64/src/pointwise_acc_l5.S)
- [mldsa/native/x86_64/src/pointwise_acc_l7.S](mldsa/native/x86_64/src/pointwise_acc_l7.S)
- [mldsa/native/x86_64/src/poly_caddq_avx2.c](mldsa/native/x86_64/src/poly_caddq_avx2.c)
- [mldsa/native/x86_64/src/poly_chknorm_avx2.c](mldsa/native/x86_64/src/poly_chknorm_avx2.c)
- [mldsa/native/x86_64/src/poly_decompose_32_avx2.c](mldsa/native/x86_64/src/poly_decompose_32_avx2.c)
Expand Down
35 changes: 35 additions & 0 deletions mldsa/native/aarch64/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#define MLD_USE_NATIVE_POLY_CHKNORM
#define MLD_USE_NATIVE_POLYZ_UNPACK_17
#define MLD_USE_NATIVE_POLYZ_UNPACK_19
/* Pointwise Montgomery multiplication: one macro for the single-polynomial
 * product and one per accumulating vector variant (4, 5 and 7 polynomials).
 * The matching *_native wrappers are defined further down in this header.
 * NOTE(review): presumably these macros are tested by the generic code to
 * select the native wrappers - confirm against the consumer. */
#define MLD_USE_NATIVE_POINTWISE_MONTGOMERY
#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4
#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5
#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7

/* Identifier for this backend so that source and assembly files
* in the build can be appropriately guarded. */
Expand Down Expand Up @@ -147,5 +151,36 @@ static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r,
mld_polyz_unpack_19_asm(r, buf, mld_polyz_unpack_19_indices);
}

/* Native entry point for the single-polynomial pointwise product.
 * Forwards out, in0 and in1 unchanged to mld_poly_pointwise_montgomery_asm.
 * NOTE(review): the assembly routine is not visible from this header;
 * presumably out receives the coefficient-wise Montgomery product of in0
 * and in1 - confirm against the .S file. */
static MLD_INLINE void mld_poly_pointwise_montgomery_native(
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],
const int32_t in1[MLDSA_N])
{
mld_poly_pointwise_montgomery_asm(out, in0, in1);
}

/* Native entry point for the accumulating pointwise product over vectors
 * of 4 polynomials.
 *
 * The assembly routine takes plain int32_t pointers, so u and v are handed
 * over as pointers to their first coefficient; w is passed through as is. */
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
    int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
    const int32_t v[4][MLDSA_N])
{
  const int32_t *u_flat = &u[0][0];
  const int32_t *v_flat = &v[0][0];

  mld_polyvecl_pointwise_acc_montgomery_l4_asm(w, u_flat, v_flat);
}

/* Native entry point for the accumulating pointwise product over vectors
 * of 5 polynomials.
 *
 * The assembly routine takes plain int32_t pointers, so u and v are handed
 * over as pointers to their first coefficient; w is passed through as is. */
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
    int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
    const int32_t v[5][MLDSA_N])
{
  const int32_t *u_flat = &u[0][0];
  const int32_t *v_flat = &v[0][0];

  mld_polyvecl_pointwise_acc_montgomery_l5_asm(w, u_flat, v_flat);
}

/* Native entry point for the accumulating pointwise product over vectors
 * of 7 polynomials.
 *
 * The assembly routine takes plain int32_t pointers, so u and v are handed
 * over as pointers to their first coefficient; w is passed through as is. */
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
    int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
    const int32_t v[7][MLDSA_N])
{
  const int32_t *u_flat = &u[0][0];
  const int32_t *v_flat = &v[0][0];

  mld_polyvecl_pointwise_acc_montgomery_l7_asm(w, u_flat, v_flat);
}

#endif /* !__ASSEMBLER__ */
#endif /* !MLD_NATIVE_AARCH64_META_H */
20 changes: 20 additions & 0 deletions mldsa/native/aarch64/src/arith_native_aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,24 @@ void mld_polyz_unpack_17_asm(int32_t *r, const uint8_t *buf,
void mld_polyz_unpack_19_asm(int32_t *r, const uint8_t *buf,
const uint8_t *indices);

/* Assembly routine behind mld_poly_pointwise_montgomery_native().
 * Arguments, in order: output polynomial, first input, second input
 * (the wrapper in meta.h passes out, in0, in1).
 * NOTE(review): the implementation is not visible here; presumably a
 * coefficient-wise Montgomery product of two MLDSA_N-coefficient
 * polynomials - confirm against the .S file. */
#define mld_poly_pointwise_montgomery_asm \
MLD_NAMESPACE(poly_pointwise_montgomery_asm)
void mld_poly_pointwise_montgomery_asm(int32_t *, const int32_t *,
const int32_t *);

/* Assembly routine from mld_polyvecl_pointwise_acc_montgomery_l4.S.
 * Arguments, in order:
 *   1. output polynomial, MLDSA_N coefficients
 *   2. first input: 4 polynomials stored back to back (1024 bytes apart)
 *   3. second input: same layout as the first
 * For every coefficient index the routine sums the 64-bit products of the
 * 4 coefficient pairs and stores the Montgomery reduction of that sum. */
#define mld_polyvecl_pointwise_acc_montgomery_l4_asm \
MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l4_asm)
void mld_polyvecl_pointwise_acc_montgomery_l4_asm(int32_t *, const int32_t *,
const int32_t *);

/* Assembly routine from mld_polyvecl_pointwise_acc_montgomery_l5.S.
 * Arguments, in order:
 *   1. output polynomial, MLDSA_N coefficients
 *   2. first input: 5 polynomials stored back to back (1024 bytes apart)
 *   3. second input: same layout as the first
 * For every coefficient index the routine sums the 64-bit products of the
 * 5 coefficient pairs and stores the Montgomery reduction of that sum. */
#define mld_polyvecl_pointwise_acc_montgomery_l5_asm \
MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l5_asm)
void mld_polyvecl_pointwise_acc_montgomery_l5_asm(int32_t *, const int32_t *,
const int32_t *);

/* Assembly routine behind mld_polyvecl_pointwise_acc_montgomery_l7_native().
 * Arguments, in order: output polynomial, first input vector, second input
 * vector (the wrapper in meta.h passes w and flat pointers to u and v, each
 * declared as 7 polynomials of MLDSA_N coefficients).
 * NOTE(review): the implementation is not visible here; presumably it
 * mirrors the l4/l5 routines with 7 accumulated products - confirm against
 * the .S file. */
#define mld_polyvecl_pointwise_acc_montgomery_l7_asm \
MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l7_asm)
void mld_polyvecl_pointwise_acc_montgomery_l7_asm(int32_t *, const int32_t *,
const int32_t *);

#endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */
126 changes: 126 additions & 0 deletions mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l4.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/* Copyright (c) The mldsa-native project authors
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
*/

#include "../../../common.h"
#if defined(MLD_ARITH_BACKEND_AARCH64)

/* Montgomery-reduce four 64-bit values held in two .2d vectors.
 *   res: destination, receives the upper 32 bits of each updated lane (.4s)
 *   inl: 64-bit lanes 0-1 (overwritten)
 *   inh: 64-bit lanes 2-3 (overwritten)
 * Steps: gather the low 32-bit words of all four lanes into t0, multiply
 * them by modulus_twisted, add the signed products t0 * modulus to the
 * 64-bit lanes, then gather the high 32-bit words into res.
 * With modulus_twisted = -q^-1 mod 2^32 (as loaded by the function in this
 * file) the added products cancel the low words, so the high words are the
 * reduced results.
 * Uses the register aliases modulus, modulus_twisted and t0; clobbers t0. */
.macro montgomery_reduce_long res, inl, inh
uzp1 t0.4s, \inl\().4s, \inh\().4s
mul t0.4s, t0.4s, modulus_twisted.4s
smlal \inl\().2d, t0.2s, modulus.2s
smlal2 \inh\().2d, t0.4s, modulus.4s
uzp2 \res\().4s, \inl\().4s, \inh\().4s
.endm

/* Load 16 bytes (four 32-bit coefficients) from each of a_ptr and b_ptr
 * and advance both pointers by 16 (post-indexed loads).
 *   a, b: register alias names WITHOUT the q_ prefix; the macro prepends
 *         q_ to select the 128-bit view (for example aa selects q_aa). */
.macro load_polys a, b, a_ptr, b_ptr
ldr q_\()\a, [\a_ptr], #16
ldr q_\()\b, [\b_ptr], #16
.endm

/* Widening signed multiply of four 32-bit lanes.
 *   dl = products of lanes 0-1 of a and b, as two 64-bit values
 *   dh = products of lanes 2-3 of a and b, as two 64-bit values
 * NOTE(review): the macro name coincides with the A64 PMULL (polynomial
 * multiply long) instruction mnemonic, which is therefore shadowed in this
 * file; consider a distinct name if that instruction is ever needed. */
.macro pmull dl, dh, a, b
smull \dl\().2d, \a\().2s, \b\().2s
smull2 \dh\().2d, \a\().4s, \b\().4s
.endm

/* Widening signed multiply-accumulate of four 32-bit lanes.
 *   dl += products of lanes 0-1 of a and b (two 64-bit accumulators)
 *   dh += products of lanes 2-3 of a and b (two 64-bit accumulators) */
.macro pmlal dl, dh, a, b
smlal \dl\().2d, \a\().2s, \b\().2s
smlal2 \dh\().2d, \a\().4s, \b\().4s
.endm

/* Reserve 64 bytes of stack and store d8-d15 there, i.e. the low 64 bits
 * of v8-v15, which AAPCS64 requires a callee to preserve.
 * Keeps sp 16-byte aligned. Must be paired with restore_vregs. */
.macro save_vregs
sub sp, sp, #(16*4)
stp d8, d9, [sp, #16*0]
stp d10, d11, [sp, #16*1]
stp d12, d13, [sp, #16*2]
stp d14, d15, [sp, #16*3]
.endm

/* Reload d8-d15 from the 64-byte area written by save_vregs and release
 * that stack area again. */
.macro restore_vregs
ldp d8, d9, [sp, #16*0]
ldp d10, d11, [sp, #16*1]
ldp d12, d13, [sp, #16*2]
ldp d14, d15, [sp, #16*3]
add sp, sp, #(16*4)
.endm

/* Function prologue helper: currently only saves d8-d15 via save_vregs. */
.macro push_stack
save_vregs
.endm

/* Function epilogue helper: undoes push_stack by restoring d8-d15 via
 * restore_vregs. */
.macro pop_stack
restore_vregs
.endm

// Incoming arguments (AAPCS64): x0 = output polynomial, x1 / x2 = base of
// the first / second input vector, i.e. polynomial 0 of each
out_ptr .req x0
a0_ptr .req x1
b0_ptr .req x2
// Pointers to polynomials 1..3 of each input vector; the function derives
// them from a0_ptr / b0_ptr. All of x3-x8 are caller-saved.
a1_ptr .req x3
b1_ptr .req x4
a2_ptr .req x5
b2_ptr .req x6
a3_ptr .req x7
b3_ptr .req x8
// Loop counter. wtmp is the 32-bit view of the same register; it is only
// used to build the constants before count is initialised.
count .req x9
wtmp .req w9

// Constants broadcast to all four 32-bit lanes
modulus .req v0
modulus_twisted .req v1

// Working vectors: aa / bb hold the loaded coefficients, resl / resh the
// 64-bit accumulators, res the reduced result, t0 is scratch for
// montgomery_reduce_long
aa .req v2
bb .req v3
res .req v4
resl .req v5
resh .req v6
t0 .req v7

// 128-bit views of aa, bb and res for ldr / str
q_aa .req q2
q_bb .req q3
q_res .req q4

/*
 * void polyvecl_pointwise_acc_montgomery_l4_asm(int32_t *w,
 *                                               const int32_t *u,
 *                                               const int32_t *v)
 *
 * ABI:   AAPCS64, leaf function, no stack usage
 * In:    x0 = w, output polynomial (MLDSA_N coefficients)
 *        x1 = u, 4 polynomials of MLDSA_N int32_t each, stored back to back
 *        x2 = v, same layout as u
 * Out:   for every coefficient index j, w[j] is the Montgomery reduction of
 *        u[0][j]*v[0][j] + u[1][j]*v[1][j] + u[2][j]*v[2][j] + u[3][j]*v[3][j]
 * Clobb: x0-x9, v0-v7
 *
 * Only the caller-saved vector registers v0-v7 are written, so d8-d15 do
 * not have to be spilled and no stack frame is set up.
 */
        .text
        .global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l4_asm)
        .balign 4
MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l4_asm)

        // modulus = q = 8380417 in every 32-bit lane
        movz wtmp, #57345
        movk wtmp, #127, lsl #16
        dup modulus.4s, wtmp

        // modulus_twisted = -q^-1 mod 2^32 = 4236238847 in every 32-bit lane
        movz wtmp, #57343
        movk wtmp, #64639, lsl #16
        dup modulus_twisted.4s, wtmp

        // Base of polynomial i in each vector; one polynomial occupies
        // 4 * MLDSA_N bytes
        add a1_ptr, a0_ptr, #(1 * 4 * MLDSA_N)
        add a2_ptr, a0_ptr, #(2 * 4 * MLDSA_N)
        add a3_ptr, a0_ptr, #(3 * 4 * MLDSA_N)

        add b1_ptr, b0_ptr, #(1 * 4 * MLDSA_N)
        add b2_ptr, b0_ptr, #(2 * 4 * MLDSA_N)
        add b3_ptr, b0_ptr, #(3 * 4 * MLDSA_N)

        // Each iteration produces four output coefficients
        mov count, #(MLDSA_N / 4)
l4_loop_start:
        // 64-bit sum of the four coefficient-wise products
        load_polys aa, bb, a0_ptr, b0_ptr
        pmull resl, resh, aa, bb
        load_polys aa, bb, a1_ptr, b1_ptr
        pmlal resl, resh, aa, bb
        load_polys aa, bb, a2_ptr, b2_ptr
        pmlal resl, resh, aa, bb
        load_polys aa, bb, a3_ptr, b3_ptr
        pmlal resl, resh, aa, bb

        montgomery_reduce_long res, resl, resh

        str q_res, [out_ptr], #16

        // cbnz tests the register itself, so the flag-setting subs is
        // not needed
        sub count, count, #1
        cbnz count, l4_loop_start

        ret
#endif /* MLD_ARITH_BACKEND_AARCH64 */
132 changes: 132 additions & 0 deletions mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l5.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/* Copyright (c) The mldsa-native project authors
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
*/

#include "../../../common.h"
#if defined(MLD_ARITH_BACKEND_AARCH64)

/* Montgomery-reduce four 64-bit values held in two .2d vectors.
 *   res: destination, receives the upper 32 bits of each updated lane (.4s)
 *   inl: 64-bit lanes 0-1 (overwritten)
 *   inh: 64-bit lanes 2-3 (overwritten)
 * Steps: gather the low 32-bit words of all four lanes into t0, multiply
 * them by modulus_twisted, add the signed products t0 * modulus to the
 * 64-bit lanes, then gather the high 32-bit words into res.
 * With modulus_twisted = -q^-1 mod 2^32 (as loaded by the function in this
 * file) the added products cancel the low words, so the high words are the
 * reduced results.
 * Uses the register aliases modulus, modulus_twisted and t0; clobbers t0. */
.macro montgomery_reduce_long res, inl, inh
uzp1 t0.4s, \inl\().4s, \inh\().4s
mul t0.4s, t0.4s, modulus_twisted.4s
smlal \inl\().2d, t0.2s, modulus.2s
smlal2 \inh\().2d, t0.4s, modulus.4s
uzp2 \res\().4s, \inl\().4s, \inh\().4s
.endm

/* Load 16 bytes (four 32-bit coefficients) from each of a_ptr and b_ptr
 * and advance both pointers by 16 (post-indexed loads).
 *   a, b: register alias names WITHOUT the q_ prefix; the macro prepends
 *         q_ to select the 128-bit view (for example aa selects q_aa). */
.macro load_polys a, b, a_ptr, b_ptr
ldr q_\()\a, [\a_ptr], #16
ldr q_\()\b, [\b_ptr], #16
.endm

/* Widening signed multiply of four 32-bit lanes.
 *   dl = products of lanes 0-1 of a and b, as two 64-bit values
 *   dh = products of lanes 2-3 of a and b, as two 64-bit values
 * NOTE(review): the macro name coincides with the A64 PMULL (polynomial
 * multiply long) instruction mnemonic, which is therefore shadowed in this
 * file; consider a distinct name if that instruction is ever needed. */
.macro pmull dl, dh, a, b
smull \dl\().2d, \a\().2s, \b\().2s
smull2 \dh\().2d, \a\().4s, \b\().4s
.endm

/* Widening signed multiply-accumulate of four 32-bit lanes.
 *   dl += products of lanes 0-1 of a and b (two 64-bit accumulators)
 *   dh += products of lanes 2-3 of a and b (two 64-bit accumulators) */
.macro pmlal dl, dh, a, b
smlal \dl\().2d, \a\().2s, \b\().2s
smlal2 \dh\().2d, \a\().4s, \b\().4s
.endm

/* Reserve 64 bytes of stack and store d8-d15 there, i.e. the low 64 bits
 * of v8-v15, which AAPCS64 requires a callee to preserve.
 * Keeps sp 16-byte aligned. Must be paired with restore_vregs. */
.macro save_vregs
sub sp, sp, #(16*4)
stp d8, d9, [sp, #16*0]
stp d10, d11, [sp, #16*1]
stp d12, d13, [sp, #16*2]
stp d14, d15, [sp, #16*3]
.endm

/* Reload d8-d15 from the 64-byte area written by save_vregs and release
 * that stack area again. */
.macro restore_vregs
ldp d8, d9, [sp, #16*0]
ldp d10, d11, [sp, #16*1]
ldp d12, d13, [sp, #16*2]
ldp d14, d15, [sp, #16*3]
add sp, sp, #(16*4)
.endm

/* Function prologue helper: currently only saves d8-d15 via save_vregs. */
.macro push_stack
save_vregs
.endm

/* Function epilogue helper: undoes push_stack by restoring d8-d15 via
 * restore_vregs. */
.macro pop_stack
restore_vregs
.endm

// Incoming arguments (AAPCS64): x0 = output polynomial, x1 / x2 = base of
// the first / second input vector, i.e. polynomial 0 of each
out_ptr .req x0
a0_ptr .req x1
b0_ptr .req x2
// Pointers to polynomials 1..4 of each input vector; the function derives
// them from a0_ptr / b0_ptr. All of x3-x10 are caller-saved.
a1_ptr .req x3
b1_ptr .req x4
a2_ptr .req x5
b2_ptr .req x6
a3_ptr .req x7
b3_ptr .req x8
a4_ptr .req x9
b4_ptr .req x10
// Loop counter. wtmp is the 32-bit view of the same register; it is only
// used to build the constants before count is initialised.
count .req x11
wtmp .req w11

// Constants broadcast to all four 32-bit lanes
modulus .req v0
modulus_twisted .req v1

// Working vectors: aa / bb hold the loaded coefficients, resl / resh the
// 64-bit accumulators, res the reduced result, t0 is scratch for
// montgomery_reduce_long
aa .req v2
bb .req v3
res .req v4
resl .req v5
resh .req v6
t0 .req v7

// 128-bit views of aa, bb and res for ldr / str
q_aa .req q2
q_bb .req q3
q_res .req q4

/*
 * void polyvecl_pointwise_acc_montgomery_l5_asm(int32_t *w,
 *                                               const int32_t *u,
 *                                               const int32_t *v)
 *
 * ABI:   AAPCS64, leaf function, no stack usage
 * In:    x0 = w, output polynomial (MLDSA_N coefficients)
 *        x1 = u, 5 polynomials of MLDSA_N int32_t each, stored back to back
 *        x2 = v, same layout as u
 * Out:   for every coefficient index j, w[j] is the Montgomery reduction of
 *        the sum of u[i][j]*v[i][j] over i = 0..4
 * Clobb: x0-x11, v0-v7
 *
 * Only the caller-saved vector registers v0-v7 are written, so d8-d15 do
 * not have to be spilled and no stack frame is set up.
 */
        .text
        .global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l5_asm)
        .balign 4
MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l5_asm)

        // modulus = q = 8380417 in every 32-bit lane
        movz wtmp, #57345
        movk wtmp, #127, lsl #16
        dup modulus.4s, wtmp

        // modulus_twisted = -q^-1 mod 2^32 = 4236238847 in every 32-bit lane
        movz wtmp, #57343
        movk wtmp, #64639, lsl #16
        dup modulus_twisted.4s, wtmp

        // Base of polynomial i in each vector; one polynomial occupies
        // 4 * MLDSA_N bytes
        add a1_ptr, a0_ptr, #(1 * 4 * MLDSA_N)
        add a2_ptr, a0_ptr, #(2 * 4 * MLDSA_N)
        add a3_ptr, a0_ptr, #(3 * 4 * MLDSA_N)
        add a4_ptr, a0_ptr, #(4 * 4 * MLDSA_N)

        add b1_ptr, b0_ptr, #(1 * 4 * MLDSA_N)
        add b2_ptr, b0_ptr, #(2 * 4 * MLDSA_N)
        add b3_ptr, b0_ptr, #(3 * 4 * MLDSA_N)
        add b4_ptr, b0_ptr, #(4 * 4 * MLDSA_N)

        // Each iteration produces four output coefficients
        mov count, #(MLDSA_N / 4)
l5_loop_start:
        // 64-bit sum of the five coefficient-wise products
        load_polys aa, bb, a0_ptr, b0_ptr
        pmull resl, resh, aa, bb
        load_polys aa, bb, a1_ptr, b1_ptr
        pmlal resl, resh, aa, bb
        load_polys aa, bb, a2_ptr, b2_ptr
        pmlal resl, resh, aa, bb
        load_polys aa, bb, a3_ptr, b3_ptr
        pmlal resl, resh, aa, bb
        load_polys aa, bb, a4_ptr, b4_ptr
        pmlal resl, resh, aa, bb

        montgomery_reduce_long res, resl, resh

        str q_res, [out_ptr], #16

        // cbnz tests the register itself, so the flag-setting subs is
        // not needed
        sub count, count, #1
        cbnz count, l5_loop_start

        ret
#endif /* MLD_ARITH_BACKEND_AARCH64 */
Loading
Loading