Skip to content

Commit

Permalink
Merge pull request #174 from avik2007/cross_spectrum_quick_fix
Browse files Browse the repository at this point in the history
Readded test_xrft.py
  • Loading branch information
Takaya Uchida committed Jan 16, 2022
2 parents 672ddd1 + fce326b commit 023c86e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
20 changes: 19 additions & 1 deletion xrft/tests/test_xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import xarray.testing as xrt

import xrft
from ..xrft import _apply_window


@pytest.fixture()
Expand Down Expand Up @@ -524,13 +525,30 @@ def test_cross_spectrum(self, dask):
cs = xrft.cross_spectrum(
da, da2, dim=dim, shift=True, window="hann", detrend="constant"
)
test = (daft * np.conj(daft2)).values / N ** 4
test = (daft * np.conj(daft2)) / N ** 4

dk = np.diff(np.fft.fftfreq(N, 1.0))[0]
test /= dk ** 2
npt.assert_almost_equal(cs.values, test)
npt.assert_almost_equal(np.ma.masked_invalid(cs).mask.sum(), 0.0)

cs = xrft.cross_spectrum(
da,
da2,
dim=dim,
shift=True,
window="hann",
detrend="constant",
window_correction=True,
)
test = (daft * np.conj(daft2)) / N ** 4
window, _ = _apply_window(da, dim, window_type="hann")
dk = np.diff(np.fft.fftfreq(N, 1.0))[0]
test /= dk ** 2 * (window ** 2).mean()

npt.assert_almost_equal(cs.values, test)
npt.assert_almost_equal(np.ma.masked_invalid(cs).mask.sum(), 0.0)

with pytest.raises(ValueError):
xrft.cross_spectrum(da, da2, dim=dim, window=None, window_correction=True)

Expand Down
4 changes: 2 additions & 2 deletions xrft/xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ def cross_spectrum(
"window_correction can only be applied when windowing is turned on."
)
else:
windows, _ = _apply_window(da, dim, window_type=kwargs.get("window"))
windows, _ = _apply_window(da1, dim, window_type=kwargs.get("window"))
cs = cs / (windows ** 2).mean()
fs = np.prod([float(cs[d].spacing) for d in updated_dims])
cs *= fs
Expand All @@ -874,7 +874,7 @@ def cross_spectrum(
"window_correction can only be applied when windowing is turned on."
)
else:
windows, _ = _apply_window(da, dim, window_type=kwargs.get("window"))
windows, _ = _apply_window(da1, dim, window_type=kwargs.get("window"))
cs = cs / windows.mean() ** 2
fs = np.prod([float(cs[d].spacing) for d in updated_dims])
cs *= fs ** 2
Expand Down

0 comments on commit 023c86e

Please sign in to comment.