Skip to content

Commit

Permalink
get N from chunk sizes when chunks_to_segments=True (#103)
Browse files Browse the repository at this point in the history
consolidate code, add comments and clarify

fix typo

fix another typo

get N from chunk sizes when chunks_to_segments=True
  • Loading branch information
dougiesquire authored Sep 10, 2020
1 parent 7bdf0b8 commit aad0d2d
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 33 deletions.
104 changes: 74 additions & 30 deletions xrft/tests/test_xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,51 +548,95 @@ def test_cross_phase_2d(self, dask):
npt.assert_almost_equal(actual_phase_offset, phase_offset)


def test_parseval():
@pytest.mark.parametrize("chunks_to_segments", [False, True])
def test_parseval(chunks_to_segments):
"""Test whether the Parseval's relation is satisfied."""

N = 16
N = 16 # Must be divisible by n_segments (below)
da = xr.DataArray(np.random.rand(N,N),
dims=['x','y'], coords={'x':range(N), 'y':range(N)})
da2 = xr.DataArray(np.random.rand(N,N),
dims=['x','y'], coords={'x':range(N), 'y':range(N)})

if chunks_to_segments:
n_segments = 2
# Chunk da and da2 into n_segments
da = da.chunk({'x': N / n_segments, 'y': N / n_segments})
da2 = da2.chunk({'x': N / n_segments, 'y': N / n_segments})
else:
n_segments = 1

dim = da.dims
fftdim = [f'freq_{d}' for d in da.dims]
delta_x = []
for d in dim:
coord = da[d]
diff = np.diff(coord)
delta = diff[0]
delta_x.append(delta)
delta_xy = np.asarray(delta_x).prod() # Area of the spacings

### Test Parseval's theorem for power_spectrum with `window=False` and detrend=None
ps = xrft.power_spectrum(da,
chunks_to_segments=chunks_to_segments)
# If n_segments > 1, use xrft._stack_chunks() to stack each segment along a new dimension
da_seg = xrft.xrft._stack_chunks(da, dim).squeeze() if chunks_to_segments else da
da_prime = da_seg
# Check that the (rectangular) integral of the spectrum matches the energy
npt.assert_almost_equal((1 / delta_xy) * ps.mean(fftdim).values,
(da_prime**2).mean(dim).values,
decimal=5)

### Test Parseval's theorem for power_spectrum with `window=True` and detrend='constant'
# Note that applying a window weighting reduces the energy in a signal and we have to account
# for this reduction when testing Parseval's theorem.
ps = xrft.power_spectrum(da,
window=True,
detrend='constant',
chunks_to_segments=chunks_to_segments)
# If n_segments > 1, use xrft._stack_chunks() to stack each segment along a new dimension
da_seg = xrft.xrft._stack_chunks(da, dim).squeeze() if chunks_to_segments else da
da_prime = da_seg - da_seg.mean(dim=dim)
# Generate the window weightings for each segment
window = xr.DataArray(
np.tile(
np.hanning(N / n_segments) * np.hanning(N / n_segments)[:, np.newaxis],
(n_segments, n_segments)
),
dims=dim, coords=da.coords
)
# Check that the (rectangular) integral of the spectrum matches the windowed variance
npt.assert_almost_equal((1 / delta_xy) * ps.mean(fftdim).values,
((da_prime*window)**2).mean(dim).values,
decimal=5)

### Test Parseval's theorem for cross_spectrum with `window=True` and detrend='constant'
cs = xrft.cross_spectrum(da, da2,
window=True,
detrend='constant',
chunks_to_segments=chunks_to_segments)
# If n_segments > 1, use xrft._stack_chunks() to stack each segment along a new dimension
da2_seg = xrft.xrft._stack_chunks(da2, dim).squeeze() if chunks_to_segments else da2
da2_prime = da2_seg - da2_seg.mean(dim=dim)
# Check that the (rectangular) integral of the cross-spectrum matches the windowed co-variance
npt.assert_almost_equal((1 / delta_xy) * cs.mean(fftdim).values,
((da_prime*window) * (da2_prime*window)).mean(dim).values,
decimal=5)

### Test Parseval's theorem for a 3D case with `window=True` and `detrend='linear'`
if not chunks_to_segments:
d3d = xr.DataArray(np.random.rand(N,N,N),
dims=['time','y','x'],
coords={'time':range(N), 'y':range(N), 'x':range(N)}
).chunk({'time':1})
ps = xrft.power_spectrum(d3d,
dim=['x','y'],
window=True,
detrend='linear')
npt.assert_almost_equal((1 / delta_xy) * ps[0].values.mean(),
((numpy_detrend(d3d[0].values)*window)**2).mean(),
decimal=5)

window = np.hanning(N) * np.hanning(N)[:, np.newaxis]
ps = xrft.power_spectrum(da, window=True, detrend='constant')
da_prime = da.values - da.mean(dim=dim).values
npt.assert_almost_equal(ps.values.sum(),
(np.asarray(delta_x).prod()
* ((da_prime*window)**2).sum()
), decimal=5
)

cs = xrft.cross_spectrum(da, da2, window=True, detrend='constant')
da2_prime = da2.values - da2.mean(dim=dim).values
npt.assert_almost_equal(cs.values.sum(),
(np.asarray(delta_x).prod()
* ((da_prime*window)
* (da2_prime*window)).sum()
), decimal=5
)

d3d = xr.DataArray(np.random.rand(N,N,N),
dims=['time','y','x'],
coords={'time':range(N), 'y':range(N), 'x':range(N)}
).chunk({'time':1})
ps = xrft.power_spectrum(d3d, dim=['x','y'], window=True, detrend='linear')
npt.assert_almost_equal(ps[0].values.sum(),
(np.asarray(delta_x).prod()
* ((numpy_detrend(d3d[0].values)*window)**2).sum()
), decimal=5
)

def synthetic_field(N, dL, amp, s):
"""
Expand Down
14 changes: 11 additions & 3 deletions xrft/xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,15 @@ def _diff_coord(coord):
else:
return np.diff(coord)


def _calc_normalization_factor(da, axis_num, chunks_to_segments):
"""Return the signal length, N, to be used in the normalisation of spectra"""

if chunks_to_segments:
# Use chunk sizes for normalisation
return [da.chunks[n][0] for n in axis_num]
else:
return [da.shape[n] for n in axis_num]

def dft(da, spacing_tol=1e-3, dim=None, real=None, shift=True, detrend=None,
window=False, chunks_to_segments=False, prefix='freq_'):
"""
Expand Down Expand Up @@ -453,7 +461,7 @@ def power_spectrum(da, spacing_tol=1e-3, dim=None, real=None, shift=True,
# the axes along which to take ffts
axis_num = [da.get_axis_num(d) for d in dim]

N = [da.shape[n] for n in axis_num]
N = _calc_normalization_factor(da, axis_num, chunks_to_segments)

return _power_spectrum(daft, dim, N, density)

Expand Down Expand Up @@ -535,7 +543,7 @@ def cross_spectrum(da1, da2, spacing_tol=1e-3, dim=None, shift=True,
# the axes along which to take ffts
axis_num = [da1.get_axis_num(d) for d in dim]

N = [da1.shape[n] for n in axis_num]
N = _calc_normalization_factor(da1, axis_num, chunks_to_segments)

return _cross_spectrum(daft1, daft2, dim, N, density)

Expand Down

0 comments on commit aad0d2d

Please sign in to comment.