Skip to content

Commit

Permalink
add tests and improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-richard committed Aug 3, 2023
1 parent feb5b8d commit 82ab086
Show file tree
Hide file tree
Showing 4 changed files with 361 additions and 153 deletions.
108 changes: 45 additions & 63 deletions pyrfu/pyrf/cotrans.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

# Local imports
from ..models import igrf
from .ts_tensor_xyz import ts_tensor_xyz
from .ts_vec_xyz import ts_vec_xyz
from .unix2datetime64 import unix2datetime64

__author__ = "Louis Richard"
__email__ = "[email protected]"
Expand Down Expand Up @@ -49,9 +51,8 @@ def _dipole_direction_gse(time, flag: str = "dipole"):
np.sin(np.deg2rad(phi)),
],
).T

dipole_direction_gse_ = cotrans(
np.hstack([time[:, None], dipole_direction_geo_]),
ts_vec_xyz(unix2datetime64(time), dipole_direction_geo_),
"geo>gse",
)

Expand All @@ -67,6 +68,8 @@ def _transformation_matrix(t, tind, hapgood, *args):
transf_mat_out[:, 2, 2] = np.ones(len(t))

for j, t_num in enumerate(tind[::-1]):
assert abs(t_num) in list(range(1, 6)), "t_num must be +/- 1, 2, 3, 4, 5"

if t_num in [-1, 1]:
if hapgood:
theta = 100.461 + 36000.770 * t_zero + 15.04107 * ut
Expand Down Expand Up @@ -116,8 +119,8 @@ def _transformation_matrix(t, tind, hapgood, *args):

elif t_num in [-3, 3]:
dipole_direction_gse_ = _dipole_direction_gse(t, "dipole")
y_e = dipole_direction_gse_[:, 2] # 1st col is time
z_e = dipole_direction_gse_[:, 3]
y_e = dipole_direction_gse_[:, 1] # 1st col is time
z_e = dipole_direction_gse_[:, 2]
psi = np.rad2deg(np.arctan(y_e / z_e))

transf_mat = _triang(-psi * np.sign(t_num), 0) # inverse if -3
Expand All @@ -126,24 +129,21 @@ def _transformation_matrix(t, tind, hapgood, *args):
dipole_direction_gse_ = _dipole_direction_gse(t, "dipole")

mu = np.arctan(
dipole_direction_gse_[:, 1]
/ np.sqrt(np.sum(dipole_direction_gse_[:, 2:] ** 2, axis=1)),
dipole_direction_gse_[:, 0]
/ np.sqrt(np.sum(dipole_direction_gse_[:, 1:] ** 2, axis=1)),
)
mu = np.rad2deg(mu)

transf_mat = _triang(-mu * np.sign(t_num), 1)

elif t_num in [-5, 5]:
else:
lambda_, phi = igrf(t, "dipole")

transf_mat = np.matmul(_triang(phi - 90, 1), _triang(lambda_, 2))
if t_num == -5:
transf_mat = np.transpose(transf_mat, [0, 2, 1])

else:
raise ValueError

if j == len(tind):
if j == 0:
transf_mat_out = transf_mat
else:
transf_mat_out = np.matmul(transf_mat, transf_mat_out)
Expand Down Expand Up @@ -205,27 +205,31 @@ def cotrans(inp, flag, hapgood: bool = True):
"""

assert isinstance(inp, xr.DataArray), "inp must be a xarray.DataArray"
assert inp.ndim < 3, "inp must be scalar or vector"

if ">" in flag:
ref_syst_in, ref_syst_out = flag.split(">")
else:
ref_syst_in, ref_syst_out = [None, flag.lower()]

if isinstance(inp, xr.DataArray):
if "COORDINATE_SYSTEM" in inp.attrs:
ref_syst_internal = inp.attrs["COORDINATE_SYSTEM"].lower()
ref_syst_internal = ref_syst_internal.split(">")[0]
else:
ref_syst_internal = None

if ref_syst_in is not None and ref_syst_internal is not None:
message = "input ref. frame in variable and input flag differs"
assert ref_syst_internal == ref_syst_in, message
elif ref_syst_in is None and ref_syst_internal is not None:
ref_syst_in = ref_syst_internal.lower()
elif ref_syst_in is None and ref_syst_internal is None:
raise ValueError("input reference frame undefined")
if "COORDINATE_SYSTEM" in inp.attrs:
ref_syst_internal = inp.attrs["COORDINATE_SYSTEM"].lower()
ref_syst_internal = ref_syst_internal.split(">")[0]
else:
ref_syst_internal = None

if ref_syst_in is not None and ref_syst_internal is not None:
message = "input ref. frame in variable and input flag differs"
assert ref_syst_internal == ref_syst_in, message
flag = f"{ref_syst_in}>{ref_syst_out}"
elif ref_syst_in is None and ref_syst_internal is not None:
ref_syst_in = ref_syst_internal.lower()
flag = f"{ref_syst_in}>{ref_syst_out}"
elif flag.lower() == "dipoledirectiongse":
flag = flag.lower()
elif ref_syst_in is None and ref_syst_internal is None:
raise ValueError(f"Transformation {flag} is unknown!")

if ref_syst_in == ref_syst_out:
return inp
Expand All @@ -234,24 +238,13 @@ def cotrans(inp, flag, hapgood: bool = True):
j2000 = 946727930.8160001
# j2000 = Time("J2000", format="jyear_str").unix

if isinstance(inp, xr.DataArray):
time = inp.time.data
t = (time.astype(np.int64) * 1e-9).astype(np.float64)

# Terrestial Time (seconds since J2000)
tts = t - j2000
inp_ts = inp
inp = inp.data

elif isinstance(inp, np.ndarray):
time = (inp[:, 0] * 1e9).astype("datetime64[ns]")
t = inp[:, 0]
# Terrestial Time (seconds since J2000)
tts = t - j2000
inp_ts = None
inp = inp[:, 1:]
else:
raise TypeError("invalid input")
time = inp.time.data
t = (time.astype(np.int64) * 1e-9).astype(np.float64)

# Terrestial Time (seconds since J2000)
tts = t - j2000
inp_ts = inp
inp = inp.data

if hapgood:
day_start_epoch = time.astype("datetime64[D]")
Expand Down Expand Up @@ -302,29 +295,18 @@ def cotrans(inp, flag, hapgood: bool = True):

tind = transformation_dict[flag]

elif flag == "dipoledirectiongse":
out_data = _dipole_direction_gse(t)
return ts_vec_xyz(inp.time.data, out_data)
transf_mat = _transformation_matrix(t, tind, hapgood, *args_trans_mat)

else:
raise ValueError(f"Transformation {flag} is unknown!")
if inp.ndim == 1:
out = ts_tensor_xyz(inp_ts.time.data, transf_mat)

transf_mat = _transformation_matrix(t, tind, hapgood, *args_trans_mat)

if inp.ndim == 2:
out = np.einsum("kji,ki->kj", transf_mat, inp)
elif inp.ndim == 1:
out = transf_mat
else:
raise ValueError

if inp_ts is not None:
out_data = out
out = inp_ts.copy()
out.data = out_data
out.attrs["COORDINATE_SYSTEM"] = ref_syst_out.upper()
else:
out_data = np.einsum("kji,ki->kj", transf_mat, inp)
out = inp_ts.copy()
out.data = out_data
out.attrs["COORDINATE_SYSTEM"] = ref_syst_out.upper()

else:
out = np.hstack([t[:, None], out])
out = _dipole_direction_gse(t)

return out
25 changes: 9 additions & 16 deletions pyrfu/pyrf/ebsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,17 @@ def _freq_int(freq_int, delta_b):
pc12_range, other_range = [False, False]

if isinstance(freq_int, str):
if freq_int.lower() == "pc12":
if freq_int.lower() == "pc35":
freq_int = [0.002, 0.1]

delta_t = 60 # local
else:
pc12_range = True

freq_int = [0.1, 5.0]

delta_t = 1 # local

elif freq_int.lower() == "pc35":
freq_int = [0.002, 0.1]

delta_t = 60 # local

else:
raise ValueError("Invalid format of interval")

fs_out = 1 / delta_t
else:
if freq_int[1] >= freq_int[0]:
Expand Down Expand Up @@ -443,9 +439,7 @@ def ebsp(e_xyz, db_xyz, b_xyz, b_bgd, xyz, freq_int, **kwargs):
if fac_matrix is None:
xyz = xyz[:-1, :]
else:
fac_matrix["t"] = fac_matrix["t"][:-1, :]

fac_matrix["rotMatrix"] = fac_matrix["rotMatrix"][:-1, :, :]
fac_matrix = fac_matrix[:-1, ...]

if want_ee:
e_xyz = e_xyz[:-1, :]
Expand Down Expand Up @@ -480,9 +474,6 @@ def ebsp(e_xyz, db_xyz, b_xyz, b_bgd, xyz, freq_int, **kwargs):
idx_nan_e = np.isnan(e_xyz.data)
idx_nan_eisr2 = np.isnan(eisr2.data)

if e_xyz.shape[1] < 3:
raise IndexError("E must be a 3D vector to be rotated to FAC")

if fac_matrix is None:
e_xyz = convert_fac(e_xyz, b_bgd, xyz)
else:
Expand Down Expand Up @@ -650,7 +641,9 @@ def ebsp(e_xyz, db_xyz, b_xyz, b_bgd, xyz, freq_int, **kwargs):
arg_ = ts_vec_xyz(time_b0, np.transpose(tmp))
we = convert_fac(arg_, fac_matrix)
else:
we = np.hstack([we[:, :2], we_z])
we = np.transpose(
np.vstack([np.transpose(we[:, :2]), np.transpose(we_z)])
)

power_e = 2 * np.pi * (we * np.conj(we)) / new_freq_mat
power_e = np.vstack([power_e.T, np.sum(power_e, axis=1)]).T
Expand Down
83 changes: 36 additions & 47 deletions pyrfu/pyrf/int_sph_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,11 @@ def int_sph_dist(vdf, speed, phi, theta, speed_grid, **kwargs):

# Overwrite projection dimension if azimuthal angle of projection
# plane is not provided. Set the azimuthal angle grid width.
if phi_grid is None or projection_dim == "1d":
projection_dim = "1d"
d_phi_grid = 1.0
elif phi_grid is not None and projection_dim.lower() in ["2d", "3d"]:
if phi_grid is not None and projection_dim.lower() in ["2d", "3d"]:
d_phi_grid = np.median(np.diff(phi_grid))
else:
raise RuntimeError(
"1d projection with phi_grid provided doesn't make sense!!",
)
projection_dim = "1d"
d_phi_grid = 1.0

# Make sure the transformation matrix is orthonormal.
x_phat = xyz[:, 0] / np.linalg.norm(xyz[:, 0]) # re-normalize
Expand Down Expand Up @@ -135,36 +131,9 @@ def int_sph_dist(vdf, speed, phi, theta, speed_grid, **kwargs):

n_mc_mat = n_mc_mat.astype(int)

if projection_base == "pol":
# Area or line element (primed)
d_a_grid = speed_grid ** (int(projection_dim[0]) - 1) * d_phi_grid * d_v_grid
d_a_grid = d_a_grid.astype(np.float64)

if projection_dim == "1d":
f_g = mc_pol_1d(
vdf,
speed,
phi,
theta,
d_v,
d_v_m,
d_phi,
d_theta,
speed_grid_edges,
d_a_grid,
v_lim,
a_lim,
n_mc_mat,
r_mat,
)
else:
raise NotImplementedError(
"2d projection on polar grid is not ready yet!!",
)

elif projection_base == "cart" and projection_dim == "2d":
if projection_base == "cart" and projection_dim == "2d":
d_a_grid = d_v_grid**2
f_g = mc_cart_2d(
f_g = _mc_cart_2d(
vdf,
speed,
phi,
Expand All @@ -182,7 +151,7 @@ def int_sph_dist(vdf, speed, phi, theta, speed_grid, **kwargs):
)
elif projection_base == "cart" and projection_dim == "3d":
d_a_grid = d_v_grid**3
f_g = mc_cart_3d(
f_g = _mc_cart_3d(
vdf,
speed,
phi,
Expand All @@ -199,11 +168,33 @@ def int_sph_dist(vdf, speed, phi, theta, speed_grid, **kwargs):
r_mat,
)
else:
raise ValueError("Invalid base!!")
# Area or line element (primed)
d_a_grid = speed_grid ** (int(projection_dim[0]) - 1) * d_phi_grid * d_v_grid
d_a_grid = d_a_grid.astype(np.float64)

if projection_dim == "1d":
pst = {"f": f_g, "vx": speed_grid, "vx_edges": speed_grid_edges}
elif projection_dim == "2d" and projection_base == "cart":
if projection_dim == "1d":
f_g = _mc_pol_1d(
vdf,
speed,
phi,
theta,
d_v,
d_v_m,
d_phi,
d_theta,
speed_grid_edges,
d_a_grid,
v_lim,
a_lim,
n_mc_mat,
r_mat,
)
else:
raise NotImplementedError(
"2d projection on polar grid is not ready yet!!",
)

if projection_dim == "2d" and projection_base == "cart":
pst = {
"f": f_g,
"vx": speed_grid,
Expand All @@ -222,15 +213,13 @@ def int_sph_dist(vdf, speed, phi, theta, speed_grid, **kwargs):
"vz_edges": speed_grid_edges,
}
else:
raise NotImplementedError(
"2d projection on polar grid is not ready yet!!",
)
pst = {"f": f_g, "vx": speed_grid, "vx_edges": speed_grid_edges}

return pst


@numba.jit(cache=True, nogil=True, parallel=True, nopython=True)
def mc_pol_1d(
def _mc_pol_1d(
vdf,
v,
phi,
Expand Down Expand Up @@ -341,7 +330,7 @@ def mc_pol_1d(


@numba.jit(cache=True, nogil=True, parallel=True, nopython=True)
def mc_cart_3d(
def _mc_cart_3d(
vdf,
v,
phi,
Expand Down Expand Up @@ -454,7 +443,7 @@ def mc_cart_3d(


@numba.jit(cache=True, nogil=True, parallel=True, nopython=True)
def mc_cart_2d(
def _mc_cart_2d(
vdf,
v,
phi,
Expand Down
Loading

0 comments on commit 82ab086

Please sign in to comment.