Skip to content

Commit

Permalink
Improve MDCT scaling to maximize accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
jmvalin committed Nov 3, 2024
1 parent 2520009 commit 88544c4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
38 changes: 28 additions & 10 deletions celt/mdct.c
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ void clt_mdct_forward_c(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scal
/* Allows us to scale with MULT16_32_Q16(), which is faster than
MULT16_32_Q15() on ARM. */
int scale_shift = st->scale_shift-1;
int headroom;
#endif
SAVE_STACK;
(void)arch;
Expand Down Expand Up @@ -195,6 +196,9 @@ void clt_mdct_forward_c(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scal
{
kiss_fft_scalar * OPUS_RESTRICT yp = f;
const kiss_twiddle_scalar *t = &trig[0];
#ifdef FIXED_POINT
opus_val32 maxval=1;
#endif
for(i=0;i<N4;i++)
{
kiss_fft_cpx yc;
Expand All @@ -208,10 +212,21 @@ void clt_mdct_forward_c(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scal
yi = S_MUL(im,t0) + S_MUL(re,t1);
yc.r = yr;
yc.i = yi;
yc.r = PSHR32(S_MUL2(yc.r, scale), scale_shift);
yc.i = PSHR32(S_MUL2(yc.i, scale), scale_shift);
yc.r = S_MUL2(yc.r, scale);
yc.i = S_MUL2(yc.i, scale);
#ifdef FIXED_POINT
maxval = MAX32(maxval, MAX32(yc.r, yc.i));
#endif
f2[st->bitrev[i]] = yc;
}
#ifdef FIXED_POINT
headroom = IMAX(0, IMIN(scale_shift-1, 28-celt_ilog2(maxval)));
for(i=0;i<N4;i++)
{
f2[i].r = PSHR32(f2[i].r, scale_shift-headroom);
f2[i].i = PSHR32(f2[i].i, scale_shift-headroom);
}
#endif
}

/* N/4 complex FFT, does not downscale anymore */
Expand All @@ -228,8 +243,8 @@ void clt_mdct_forward_c(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scal
for(i=0;i<N4;i++)
{
kiss_fft_scalar yr, yi;
yr = S_MUL(fp->i,t[N4+i]) - S_MUL(fp->r,t[i]);
yi = S_MUL(fp->r,t[N4+i]) + S_MUL(fp->i,t[i]);
yr = PSHR32(S_MUL(fp->i,t[N4+i]) - S_MUL(fp->r,t[i]), headroom);
yi = PSHR32(S_MUL(fp->r,t[N4+i]) + S_MUL(fp->i,t[i]), headroom);
*yp1 = yr;
*yp2 = yi;
fp++;
Expand Down Expand Up @@ -272,9 +287,12 @@ void clt_mdct_backward_c(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_sca
{
int rev;
kiss_fft_scalar yr, yi;
opus_val32 x1, x2;
rev = *bitrev++;
yr = ADD32_ovflw(S_MUL(*xp2, t[i]), S_MUL(*xp1, t[N4+i]));
yi = SUB32_ovflw(S_MUL(*xp1, t[i]), S_MUL(*xp2, t[N4+i]));
x1 = SHL32(*xp1, IMDCT_HEADROOM);
x2 = SHL32(*xp2, IMDCT_HEADROOM);
yr = ADD32_ovflw(S_MUL(x2, t[i]), S_MUL(x1, t[N4+i]));
yi = SUB32_ovflw(S_MUL(x1, t[i]), S_MUL(x2, t[N4+i]));
/* We swap real and imag because we use an FFT instead of an IFFT. */
yp[2*rev+1] = yr;
yp[2*rev] = yi;
Expand Down Expand Up @@ -304,8 +322,8 @@ void clt_mdct_backward_c(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_sca
t0 = t[i];
t1 = t[N4+i];
/* We'd scale up by 2 here, but instead it's done when mixing the windows */
yr = ADD32_ovflw(S_MUL(re,t0), S_MUL(im,t1));
yi = SUB32_ovflw(S_MUL(re,t1), S_MUL(im,t0));
yr = PSHR32(ADD32_ovflw(S_MUL(re,t0), S_MUL(im,t1)), IMDCT_HEADROOM);
yi = PSHR32(SUB32_ovflw(S_MUL(re,t1), S_MUL(im,t0)), IMDCT_HEADROOM);
/* We swap real and imag because we're using an FFT instead of an IFFT. */
re = yp1[1];
im = yp1[0];
Expand All @@ -315,8 +333,8 @@ void clt_mdct_backward_c(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_sca
t0 = t[(N4-i-1)];
t1 = t[(N2-i-1)];
/* We'd scale up by 2 here, but instead it's done when mixing the windows */
yr = ADD32_ovflw(S_MUL(re,t0), S_MUL(im,t1));
yi = SUB32_ovflw(S_MUL(re,t1), S_MUL(im,t0));
yr = PSHR32(ADD32_ovflw(S_MUL(re,t0), S_MUL(im,t1)), IMDCT_HEADROOM);
yi = PSHR32(SUB32_ovflw(S_MUL(re,t1), S_MUL(im,t0)), IMDCT_HEADROOM);
yp1[0] = yr;
yp0[1] = yi;
yp0 += 2;
Expand Down
3 changes: 3 additions & 0 deletions celt/mdct.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ typedef struct {
#include "arm/mdct_arm.h"
#endif

/* There should be 2 bits of headroom in the IMDCT which we can take
advantage of to maximize accuracy. */
#define IMDCT_HEADROOM 2

int clt_mdct_init(mdct_lookup *l,int N, int maxshift, int arch);
void clt_mdct_clear(mdct_lookup *l, int arch);
Expand Down

0 comments on commit 88544c4

Please sign in to comment.