Skip to content

Commit

Permalink
Add xi zeta tensor calculation to the cwrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
AWehrhahn committed May 11, 2023
1 parent e587d70 commit f30fec8
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 1 deletion.
33 changes: 33 additions & 0 deletions pyreduce/clib/slit_func_2d_xi_zeta_bd.c
Original file line number Diff line number Diff line change
Expand Up @@ -1278,3 +1278,36 @@ int slit_func_curved(int ncols,

return 0;
}

int create_spectral_model(int ncols, int nrows, int osample, xi_ref* xi, double* spec, double* slitfunc, double* img){
int ny, pix_x, pix_y, x, iy, m;
double pix_w;

ny = (nrows + 1) * osample + 1;

for (x = 0; x < ncols; x++)
{
for (iy = 0; iy < nrows+1; iy++)
{
img[im_index(x, iy)] = 0;
}

}

for (x = 0; x < ncols; x++)
{
for (iy = 0; iy < ny; iy++)
{
for (m = 0; m < 4; m++)
{
pix_x = xi[xi_index(x, iy, m)].x;
pix_y = xi[xi_index(x, iy, m)].y;
pix_w = xi[xi_index(x, iy, m)].w;
if ((pix_x != -1) && (pix_y != -1) && (pix_w != 0)){
img[im_index(pix_x, pix_y)] += pix_w * spec[x] * slitfunc[iy];
}
}
}
}
return 0;
}
22 changes: 22 additions & 0 deletions pyreduce/clib/slit_func_2d_xi_zeta_bd.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,25 @@ int slit_func_curved(int ncols,
double *model,
double *unc,
double *info);

int xi_zeta_tensors(
int ncols,
int nrows,
int ny,
double *ycen,
int *ycen_offset,
int y_lower_lim,
int osample,
double *PSF_curve,
xi_ref *xi,
zeta_ref *zeta,
int *m_zeta);

int create_spectral_model(
int ncols,
int nrows,
int osample,
xi_ref* xi,
double* spec,
double* slitfunc,
double* img);
86 changes: 85 additions & 1 deletion pyreduce/cwrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

try:
from .clib._slitfunc_2d import lib as slitfunc_2dlib
from .clib._slitfunc_bd import ffi
from .clib._slitfunc_2d import ffi
from .clib._slitfunc_bd import lib as slitfunclib
except ImportError: # pragma: no cover
logger.error(
Expand Down Expand Up @@ -327,3 +327,87 @@ def slitfunc_curved(
mask = mask == 0

return sp, sl, model, unc, mask, info


# x, y, w
xi_ref = [("x", c_int), ("y", c_int), ("w", c_double)]
# x, iy, w
zeta_ref = [("x", c_int), ("iy", c_int), ("w", c_double)]


def xi_zeta_tensors(
ncols: int,
nrows: int,
ycen: np.ndarray,
yrange, # (int, int)
osample: int,
tilt: np.ndarray,
shear: np.ndarray,
):
ncols = int(ncols)
nrows = int(nrows)
osample = int(osample)
ny = osample * (nrows + 1) + 1

ycen_offset = ycen.astype(c_int)
ycen_int = ycen - ycen_offset
y_lower_lim = int(yrange[0])

psf_curve = np.zeros((ncols, 3), dtype=c_double)
psf_curve[:, 1] = tilt
psf_curve[:, 2] = shear

requirements = ["C", "A", "W", "O"]
ycen_int = np.require(ycen_int, dtype=c_double, requirements=requirements)
ycen_offset = np.require(ycen_offset, dtype=c_int, requirements=requirements)

xi = np.empty((ncols, ny, 4), dtype=xi_ref)
zeta = np.empty((ncols, nrows, 3 * (osample + 1)), dtype=zeta_ref)
m_zeta = np.empty((ncols, nrows), dtype=c_int)

slitfunc_2dlib.xi_zeta_tensors(
ffi.cast("int", ncols),
ffi.cast("int", nrows),
ffi.cast("int", ny),
ffi.cast("double *", ycen_int.ctypes.data),
ffi.cast("int *", ycen_offset.ctypes.data),
ffi.cast("int", y_lower_lim),
ffi.cast("int", osample),
ffi.cast("double *", psf_curve.ctypes.data),
ffi.cast("xi_ref *", xi.ctypes.data),
ffi.cast("zeta_ref *", zeta.ctypes.data),
ffi.cast("int *", m_zeta.ctypes.data),
)

return xi, zeta, m_zeta


def create_spectral_model(
ncols: int,
nrows: int,
osample: int,
xi: "xi_ref",
spec: np.ndarray,
slitfunc: np.ndarray,
):

ncols = int(ncols)
nrows = int(nrows)

requirements = ["C", "A", "W", "O"]
spec = np.require(spec, dtype=c_double, requirements=requirements)
slitfunc = np.require(slitfunc, dtype=c_double, requirements=requirements)
xi = np.require(xi, dtype=xi_ref, requirements=requirements)

img = np.empty((nrows + 1, ncols), dtype=c_double)

slitfunc_2dlib.create_spectral_model(
ffi.cast("int", ncols),
ffi.cast("int", nrows),
ffi.cast("int", osample),
ffi.cast("xi_ref *", xi.ctypes.data),
ffi.cast("double *", spec.ctypes.data),
ffi.cast("double *", slitfunc.ctypes.data),
ffi.cast("double *", img.ctypes.data),
)
return img

0 comments on commit f30fec8

Please sign in to comment.