Skip to content

Commit

Permalink
Add phase_delay out argument
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Jun 6, 2019
1 parent 938db0b commit 83deee7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
31 changes: 26 additions & 5 deletions africanus/rime/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,40 @@

from africanus.constants import minus_two_pi_over_c
from africanus.util.docs import DocstringTemplate
from africanus.util.numba import generated_jit
from africanus.util.numba import generated_jit, njit, is_numba_type_none
from africanus.util.type_inference import infer_complex_dtype


def _out_factory(out_present):
if out_present:
def impl(out, shape, dtype):
# TODO(sjperkins) Check the dtype too?
if out.shape != shape:
raise ValueError("out.shape does not match expected shape")

return out
else:
def impl(out, shape, dtype):
return np.zeros(shape, dtype)

return njit(nogil=True, cache=True)(impl)


@generated_jit(nopython=True, nogil=True, cache=True)
def phase_delay(lm, uvw, frequency):
def phase_delay(lm, uvw, frequency, out=None):
have_out = not is_numba_type_none(out)

# Bake constants in with the correct type
one = lm.dtype(1.0)
neg_two_pi_over_c = lm.dtype(minus_two_pi_over_c)

out_dtype = infer_complex_dtype(lm, uvw, frequency)

create_output = _out_factory(have_out)

@wraps(phase_delay)
def _phase_delay_impl(lm, uvw, frequency):
def _phase_delay_impl(lm, uvw, frequency, out=None):
shape = (lm.shape[0], uvw.shape[0], frequency.shape[0])
complex_phase = np.zeros(shape, dtype=out_dtype)
complex_phase = create_output(out, shape, out_dtype)

# For each source
for source in range(lm.shape[0]):
Expand Down Expand Up @@ -87,6 +105,9 @@ def _phase_delay_impl(lm, uvw, frequency):
U, V and W components in the last dimension.
frequency : $(array_type)
frequencies of shape :code:`(chan,)`
out : $(array_type), optional
Array holding the output results. Should have the
same shape as the returned `complex_phase`.
Returns
-------
Expand Down
10 changes: 10 additions & 0 deletions africanus/rime/tests/test_rime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Tests for `codex-africanus` package."""

import numpy as np
from numpy.testing import assert_array_equal

import pytest

Expand Down Expand Up @@ -45,6 +46,15 @@ def test_phase_delay():
phase = minus_two_pi_over_c*(u*l + v*m + w*n)*freq
assert np.all(np.exp(1j*phase) == complex_phase[lm_i, uvw_i, freq_i])

# Test that we can supply an out parameter
out = np.zeros_like(complex_phase)
complex_phase_2 = phase_delay(lm, uvw, frequency, out=out)

# Result matches first version
assert_array_equal(complex_phase, complex_phase_2)
# Check that the result is in the original variable we passed in
assert out is complex_phase_2


def test_feed_rotation():
import numpy as np
Expand Down

0 comments on commit 83deee7

Please sign in to comment.