Skip to content

Commit

Permalink
Merge pull request #531 from JBorrow/fix_fft_units
Browse files Browse the repository at this point in the history
Change FFT return units to be same as input
  • Loading branch information
neutrinoceros authored Oct 25, 2024
2 parents 55f1ac4 + 24b0e1d commit e04fee7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
28 changes: 14 additions & 14 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,72 +358,72 @@ def block(arrays):

@implements(np.fft.fft)
def ftt_fft(a, *args, **kwargs):
return np.fft.fft._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.fft._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.fft2)
def ftt_fft2(a, *args, **kwargs):
return np.fft.fft2._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.fft2._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.fftn)
def ftt_fftn(a, *args, **kwargs):
return np.fft.fftn._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.fftn._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.hfft)
def ftt_hfft(a, *args, **kwargs):
return np.fft.hfft._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.hfft._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.rfft)
def ftt_rfft(a, *args, **kwargs):
return np.fft.rfft._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.rfft._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.rfft2)
def ftt_rfft2(a, *args, **kwargs):
return np.fft.rfft2._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.rfft2._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.rfftn)
def ftt_rfftn(a, *args, **kwargs):
return np.fft.rfftn._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.rfftn._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.ifft)
def ftt_ifft(a, *args, **kwargs):
return np.fft.ifft._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.ifft._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.ifft2)
def ftt_ifft2(a, *args, **kwargs):
return np.fft.ifft2._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.ifft2._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.ifftn)
def ftt_ifftn(a, *args, **kwargs):
return np.fft.ifftn._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.ifftn._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.ihfft)
def ftt_ihfft(a, *args, **kwargs):
return np.fft.ihfft._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.ihfft._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.irfft)
def ftt_irfft(a, *args, **kwargs):
return np.fft.irfft._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.irfft._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.irfft2)
def ftt_irfft2(a, *args, **kwargs):
return np.fft.irfft2._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.irfft2._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.irfftn)
def ftt_irfftn(a, *args, **kwargs):
return np.fft.irfftn._implementation(np.asarray(a), *args, **kwargs) / a.units
return np.fft.irfftn._implementation(np.asarray(a), *args, **kwargs) * a.units


@implements(np.fft.fftshift)
Expand Down
4 changes: 2 additions & 2 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ def test_fft_1D(func):
x1 = [0, 1, 2] * cm
res = func(x1)
assert type(res) is unyt_array
assert res.units == (1 / cm).units
assert res.units == (1 * cm).units


@pytest.mark.parametrize(
Expand All @@ -943,7 +943,7 @@ def test_fft_ND(func):
x1 = [[0, 1, 2], [0, 1, 2], [0, 1, 2]] * cm
res = func(x1)
assert type(res) is unyt_array
assert res.units == (1 / cm).units
assert res.units == (1 * cm).units


@pytest.mark.parametrize("func", [np.fft.fftshift, np.fft.ifftshift])
Expand Down

0 comments on commit e04fee7

Please sign in to comment.