diff --git a/xrft/tests/test_xrft.py b/xrft/tests/test_xrft.py index bd1cbf06..41414b01 100644 --- a/xrft/tests/test_xrft.py +++ b/xrft/tests/test_xrft.py @@ -548,51 +548,95 @@ def test_cross_phase_2d(self, dask): npt.assert_almost_equal(actual_phase_offset, phase_offset) -def test_parseval(): +@pytest.mark.parametrize("chunks_to_segments", [False, True]) +def test_parseval(chunks_to_segments): """Test whether the Parseval's relation is satisfied.""" - N = 16 + N = 16 # Must be divisible by n_segments (below) da = xr.DataArray(np.random.rand(N,N), dims=['x','y'], coords={'x':range(N), 'y':range(N)}) da2 = xr.DataArray(np.random.rand(N,N), dims=['x','y'], coords={'x':range(N), 'y':range(N)}) + + if chunks_to_segments: + n_segments = 2 + # Chunk da and da2 into n_segments + da = da.chunk({'x': N / n_segments, 'y': N / n_segments}) + da2 = da2.chunk({'x': N / n_segments, 'y': N / n_segments}) + else: + n_segments = 1 dim = da.dims + fftdim = [f'freq_{d}' for d in da.dims] delta_x = [] for d in dim: coord = da[d] diff = np.diff(coord) delta = diff[0] delta_x.append(delta) + delta_xy = np.asarray(delta_x).prod() # Area of the spacings + + ### Test Parseval's theorem for power_spectrum with `window=False` and detrend=None + ps = xrft.power_spectrum(da, + chunks_to_segments=chunks_to_segments) + # If n_segments > 1, use xrft._stack_chunks() to stack each segment along a new dimension + da_seg = xrft.xrft._stack_chunks(da, dim).squeeze() if chunks_to_segments else da + da_prime = da_seg + # Check that the (rectangular) integral of the spectrum matches the energy + npt.assert_almost_equal((1 / delta_xy) * ps.mean(fftdim).values, + (da_prime**2).mean(dim).values, + decimal=5) + + ### Test Parseval's theorem for power_spectrum with `window=True` and detrend='constant' + # Note that applying a window weighting reduces the energy in a signal and we have to account + # for this reduction when testing Parseval's theorem. + ps = xrft.power_spectrum(da, + window=True, + detrend='constant', + chunks_to_segments=chunks_to_segments) + # If n_segments > 1, use xrft._stack_chunks() to stack each segment along a new dimension + da_seg = xrft.xrft._stack_chunks(da, dim).squeeze() if chunks_to_segments else da + da_prime = da_seg - da_seg.mean(dim=dim) + # Generate the window weightings for each segment + window = xr.DataArray( + np.tile( + np.hanning(N / n_segments) * np.hanning(N / n_segments)[:, np.newaxis], + (n_segments, n_segments) + ), + dims=dim, coords=da.coords + ) + # Check that the (rectangular) integral of the spectrum matches the windowed variance + npt.assert_almost_equal((1 / delta_xy) * ps.mean(fftdim).values, + ((da_prime*window)**2).mean(dim).values, + decimal=5) + + ### Test Parseval's theorem for cross_spectrum with `window=True` and detrend='constant' + cs = xrft.cross_spectrum(da, da2, + window=True, + detrend='constant', + chunks_to_segments=chunks_to_segments) + # If n_segments > 1, use xrft._stack_chunks() to stack each segment along a new dimension + da2_seg = xrft.xrft._stack_chunks(da2, dim).squeeze() if chunks_to_segments else da2 + da2_prime = da2_seg - da2_seg.mean(dim=dim) + # Check that the (rectangular) integral of the cross-spectrum matches the windowed co-variance + npt.assert_almost_equal((1 / delta_xy) * cs.mean(fftdim).values, + ((da_prime*window) * (da2_prime*window)).mean(dim).values, + decimal=5) + + ### Test Parseval's theorem for a 3D case with `window=True` and `detrend='linear'` + if not chunks_to_segments: + d3d = xr.DataArray(np.random.rand(N,N,N), + dims=['time','y','x'], + coords={'time':range(N), 'y':range(N), 'x':range(N)} + ).chunk({'time':1}) + ps = xrft.power_spectrum(d3d, + dim=['x','y'], + window=True, + detrend='linear') + npt.assert_almost_equal((1 / delta_xy) * ps[0].values.mean(), + ((numpy_detrend(d3d[0].values)*window)**2).mean(), + decimal=5) - window = np.hanning(N) * np.hanning(N)[:, np.newaxis] - ps = xrft.power_spectrum(da, window=True, detrend='constant') - da_prime = da.values - da.mean(dim=dim).values - npt.assert_almost_equal(ps.values.sum(), - (np.asarray(delta_x).prod() - * ((da_prime*window)**2).sum() - ), decimal=5 - ) - - cs = xrft.cross_spectrum(da, da2, window=True, detrend='constant') - da2_prime = da2.values - da2.mean(dim=dim).values - npt.assert_almost_equal(cs.values.sum(), - (np.asarray(delta_x).prod() - * ((da_prime*window) - * (da2_prime*window)).sum() - ), decimal=5 - ) - - d3d = xr.DataArray(np.random.rand(N,N,N), - dims=['time','y','x'], - coords={'time':range(N), 'y':range(N), 'x':range(N)} - ).chunk({'time':1}) - ps = xrft.power_spectrum(d3d, dim=['x','y'], window=True, detrend='linear') - npt.assert_almost_equal(ps[0].values.sum(), - (np.asarray(delta_x).prod() - * ((numpy_detrend(d3d[0].values)*window)**2).sum() - ), decimal=5 - ) def synthetic_field(N, dL, amp, s): """ diff --git a/xrft/xrft.py b/xrft/xrft.py index f19eac61..d02cefb7 100644 --- a/xrft/xrft.py +++ b/xrft/xrft.py @@ -265,7 +265,15 @@ def _diff_coord(coord): else: return np.diff(coord) - +def _calc_normalization_factor(da, axis_num, chunks_to_segments): + """Return the signal length, N, to be used in the normalisation of spectra""" + + if chunks_to_segments: + # Use chunk sizes for normalisation + return [da.chunks[n][0] for n in axis_num] + else: + return [da.shape[n] for n in axis_num] + def dft(da, spacing_tol=1e-3, dim=None, real=None, shift=True, detrend=None, window=False, chunks_to_segments=False, prefix='freq_'): """ @@ -453,7 +461,7 @@ def power_spectrum(da, spacing_tol=1e-3, dim=None, real=None, shift=True, # the axes along which to take ffts axis_num = [da.get_axis_num(d) for d in dim] - N = [da.shape[n] for n in axis_num] + N = _calc_normalization_factor(da, axis_num, chunks_to_segments) return _power_spectrum(daft, dim, N, density) @@ -535,7 +543,7 @@ def cross_spectrum(da1, da2, spacing_tol=1e-3, dim=None, shift=True, # the axes along which to take ffts axis_num = [da1.get_axis_num(d) for d in dim] - N = [da1.shape[n] for n in axis_num] + N = _calc_normalization_factor(da1, axis_num, chunks_to_segments) return _cross_spectrum(daft1, daft2, dim, N, density)