Skip to content

Commit

Permalink
BUG: fix DCT normalisation
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Jul 28, 2022
1 parent a8a2ca7 commit 9e27220
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 8 deletions.
10 changes: 5 additions & 5 deletions dctdlt.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ void dctdlt(unsigned int n, unsigned int stride_in, const double* dct,
double a, b;
unsigned int i, j;

a = 0.5;
b = 1.;
a = 1.;
b = 2.;
if(n > 0)
{
*dlt = b*dct[0];
*dlt = dct[0];
for(j = 2; j < n; j += 2)
{
b *= (j-3.)/(j+1.);
Expand Down Expand Up @@ -93,11 +93,11 @@ void dltdct(unsigned int n, unsigned int stride_in, const double* dlt,
double a, b;
unsigned int i, j;

a = 2.;
a = 1.;
b = 1.;
if(n > 0)
{
*dct = b*dlt[0];
*dct = dlt[0];
for(j = 2; j < n; j += 2)
{
b *= ((j-1.)*(j-1.))/(j*j);
Expand Down
12 changes: 10 additions & 2 deletions flt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ def dlt(a, closed=False):
dcttype = 3

# compute the DCT coefficients
b = idct(a, type=dcttype, axis=-1, norm=None)
b = idct(a, type=dcttype, axis=-1, norm='backward')

# fix last coefficient for DCT-I
if closed:
b[-1] /= 2

# memview for C interop
cdef double[::1] b_ = b
Expand Down Expand Up @@ -222,8 +226,12 @@ def idlt(b, closed=False):
# transform DLT coefficients to DCT coefficients using C function
dltdct(n, 1, &b_[0], 1, &a_[0])

# fix last coefficient for DCT-I
if closed:
a[-1] *= 2

# perform the DCT
return dct(a, type=dcttype, axis=-1, norm=None, overwrite_x=True)
return dct(a, type=dcttype, axis=-1, norm='backward', overwrite_x=True)


def dltmtx(n, closed=False):
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = flt
version = 2022.7.27
version = 2022.7.27.1
maintainer = Nicolas Tessore
maintainer_email = [email protected]
description = fast Legendre transform
Expand Down
30 changes: 30 additions & 0 deletions test_flt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import pytest

from flt import dlt, idlt, theta


@pytest.mark.parametrize('n', [2, 5, 10, 11, 100, 101, 1000, 1001])
@pytest.mark.parametrize('closed', [False, True])
def test_dlt(n, closed):
t = theta(n, closed=closed)
for i in range(n):
a = np.zeros(n)
a[i] = 1

f = np.polynomial.legendre.legval(np.cos(t), a)

np.testing.assert_allclose(dlt(f, closed=closed), a, rtol=0, atol=1e-12)


@pytest.mark.parametrize('n', [2, 5, 10, 11, 100, 101, 1000, 1001])
@pytest.mark.parametrize('closed', [False, True])
def test_idlt(n, closed):
t = theta(n, closed=closed)
for i in range(n):
a = np.zeros(n)
a[i] = 1

f = np.polynomial.legendre.legval(np.cos(t), a)

np.testing.assert_allclose(idlt(a, closed=closed), f, rtol=0, atol=1e-10)

0 comments on commit 9e27220

Please sign in to comment.