Skip to content

Commit

Permalink
Add low level functions for clouds via pybind (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
makepath-alex authored Jan 12, 2025
1 parent 6b91a40 commit 937783b
Showing 1 changed file with 230 additions and 60 deletions.
290 changes: 230 additions & 60 deletions pybind_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1701,66 +1701,6 @@ PYBIND11_MODULE(pyrte_rrtmgp, m) {
);
});

m.def("rrtmgp_compute_tau_rayleigh",
[](
int ncol,
int nlay,
int nband,
int ngpt,
int ngas,
int nflav,
int neta,
int npres,
int ntemp,
py::array_t<int> gpoint_flavor,
py::array_t<int> band_lims_gpt,
py::array_t<Float> krayl,
int idx_h2o,
py::array_t<Float> col_dry,
py::array_t<Float> col_gas,
py::array_t<Float> fminor,
py::array_t<int> jeta,
py::array_t<Bool> tropo,
py::array_t<int> jtemp,
py::array_t<Float> tau_rayleigh

) {

py::buffer_info buf_gpoint_flavor = gpoint_flavor.request();
py::buffer_info buf_band_lims_gpt = band_lims_gpt.request();
py::buffer_info buf_krayl = krayl.request();
py::buffer_info buf_col_dry = col_dry.request();
py::buffer_info buf_col_gas = col_gas.request();
py::buffer_info buf_fminor = fminor.request();
py::buffer_info buf_jeta = jeta.request();
py::buffer_info buf_tropo = tropo.request();
py::buffer_info buf_jtemp = jtemp.request();
py::buffer_info buf_tau_rayleigh = tau_rayleigh.request();

fortran::rrtmgp_compute_tau_rayleigh(

ncol,
nlay,
nband,
ngpt,
ngas,
nflav,
neta,
npres,
ntemp,
reinterpret_cast<int *> (buf_gpoint_flavor.ptr),
reinterpret_cast<int *> (buf_band_lims_gpt.ptr),
reinterpret_cast<Float *> (buf_krayl.ptr),
idx_h2o,
reinterpret_cast<Float *> (buf_col_dry.ptr),
reinterpret_cast<Float *> (buf_col_gas.ptr),
reinterpret_cast<Float *> (buf_fminor.ptr),
reinterpret_cast<int *> (buf_jeta.ptr),
reinterpret_cast<int *> (buf_tropo.ptr),
reinterpret_cast<int *> (buf_jtemp.ptr),
reinterpret_cast<Float *> (buf_tau_rayleigh.ptr)
);
});

m.def("rrtmgp_compute_Planck_source",
[](
Expand Down Expand Up @@ -1845,4 +1785,234 @@ PYBIND11_MODULE(pyrte_rrtmgp, m) {
reinterpret_cast<Float *>(buf_sfc_src_jac.ptr)
);
});

m.def("rrtmgp_compute_tau_rayleigh",
[](
int ncol,
int nlay,
int nbnd,
int ngpt,
int ngas,
int nflav,
int neta,
int npres,
int ntemp,
py::array_t<int> gpoint_flavor,
py::array_t<int> band_lims_gpt,
py::array_t<Float> krayl,
int idx_h2o,
py::array_t<Float> col_dry,
py::array_t<Float> col_gas,
py::array_t<Float> fminor,
py::array_t<int> jeta,
py::array_t<Bool> tropo,
py::array_t<int> jtemp,
py::array_t<Float> tau_rayleigh
) {
if (ncol <= 0 || nlay <= 0 || nbnd <= 0 ||
ngpt <= 0 || ngas <= 0 || nflav <= 0 ||
neta <= 0 || ntemp <= 0
) {
throw std::runtime_error("ncol, nlay, nbnd, ngpt, ngas, nflav, neta and ntemp must be positive integers");
}

if (gpoint_flavor.size() != 2 * ngpt) throw std::runtime_error("Invalid size for input array 'gpoint_flavor'");
if (band_lims_gpt.size() != 2 * nbnd) throw std::runtime_error("Invalid size for input array 'band_lims_gpt'");
if (krayl.size() != ntemp * neta * ngpt * 2) throw std::runtime_error("Invalid size for input array 'krayl'");
if (col_dry.size() != ncol * nlay) throw std::runtime_error("Invalid size for input array 'col_dry'");
if (col_gas.size() != ncol * nlay * ngas) throw std::runtime_error("Invalid size for input array 'col_gas'");
if (fminor.size() != 2 * 2 * ncol * nlay * nflav) throw std::runtime_error("Invalid size for input array 'fminor'");
if (jeta.size() != 2 * ncol * nlay * nflav) throw std::runtime_error("Invalid size for input array 'jeta'");
if (tropo.size() != ncol * nlay) throw std::runtime_error("Invalid size for input array 'tropo'");
if (jtemp.size() != ncol * nlay) throw std::runtime_error("Invalid size for input array 'jtemp'");
if (tau_rayleigh.size() != ncol * nlay * ngpt) throw std::runtime_error("Invalid size for input array 'tau_rayleigh'");

py::buffer_info buf_gpoint_flavor = gpoint_flavor.request();
py::buffer_info buf_band_lims_gpt = band_lims_gpt.request();
py::buffer_info buf_krayl = krayl.request();
py::buffer_info buf_col_dry = col_dry.request();
py::buffer_info buf_col_gas = col_gas.request();
py::buffer_info buf_fminor = fminor.request();
py::buffer_info buf_jeta = jeta.request();
py::buffer_info buf_tropo = tropo.request();
py::buffer_info buf_jtemp = jtemp.request();
py::buffer_info buf_tau_rayleigh = tau_rayleigh.request();

fortran::rrtmgp_compute_tau_rayleigh(
ncol,
nlay,
nbnd,
ngpt,
ngas,
nflav,
neta,
npres,
ntemp,
reinterpret_cast<int*> (buf_gpoint_flavor.ptr),
reinterpret_cast<int*> (buf_band_lims_gpt.ptr),
reinterpret_cast<Float*> (buf_krayl.ptr),
idx_h2o,
reinterpret_cast<Float*> (buf_col_dry.ptr),
reinterpret_cast<Float*> (buf_col_gas.ptr),
reinterpret_cast<Float*> (buf_fminor.ptr),
reinterpret_cast<int*> (buf_jeta.ptr),
reinterpret_cast<int*> (buf_tropo.ptr),
reinterpret_cast<int*> (buf_jtemp.ptr),
reinterpret_cast<Float*> (buf_tau_rayleigh.ptr)
);
});

m.def("rrtmgp_compute_cld_from_table",
[](
int ncol,
int nlay,
int nbnd,
int nsteps,
py::array_t<Bool> mask,
py::array_t<Float> lwp,
py::array_t<Float> re,
Float step_size,
Float offset,
py::array_t<Float> tau_table,
py::array_t<Float> ssa_table,
py::array_t<Float> asy_table,
py::array_t<Float> tau,
py::array_t<Float> taussa,
py::array_t<Float> taussag
) {
if (ncol <= 0 || nlay <= 0 || nbnd <= 0 || nsteps <= 0) {
throw std::runtime_error("ncol, nlay, nbnd and nsteps must be positive integers");
}

if (mask.size() != ncol * nlay) throw std::runtime_error("Invalid size for input array 'mask'");
if (lwp.size() != ncol * nlay) throw std::runtime_error("Invalid size for input array 'lwp'");
if (re.size() != ncol * nlay) throw std::runtime_error("Invalid size for input array 're'");
if (tau_table.size() != nsteps * nbnd) throw std::runtime_error("Invalid size for input array 'tau_table'");
if (ssa_table.size() != nsteps * nbnd) throw std::runtime_error("Invalid size for input array 'ssa_table'");
if (asy_table.size() != nsteps * nbnd) throw std::runtime_error("Invalid size for input array 'asy_table'");
if (tau.size() != ncol * nlay * nbnd) throw std::runtime_error("Invalid size for input array 'tau'");
if (taussa.size() != ncol * nlay * nbnd) throw std::runtime_error("Invalid size for input array 'taussa'");
if (taussag.size() != ncol * nlay * nbnd) throw std::runtime_error("Invalid size for input array 'taussag'");

py::buffer_info buf_mask = mask.request();
py::buffer_info buf_lwp = lwp.request();
py::buffer_info buf_re = re.request();
py::buffer_info buf_tau_table = tau_table.request();
py::buffer_info buf_ssa_table = ssa_table.request();
py::buffer_info buf_asy_table = asy_table.request();
py::buffer_info buf_tau = tau.request();
py::buffer_info buf_taussa = taussa.request();
py::buffer_info buf_taussag = taussag.request();

fortran::rrtmgp_compute_cld_from_table(
ncol,
nlay,
nbnd,
nsteps,
reinterpret_cast<Bool*>(buf_mask.ptr),
reinterpret_cast<Float*>(buf_lwp.ptr),
reinterpret_cast<Float*>(buf_re.ptr),
step_size,
offset,
reinterpret_cast<Float*>(buf_tau_table.ptr),
reinterpret_cast<Float*>(buf_ssa_table.ptr),
reinterpret_cast<Float*>(buf_asy_table.ptr),
reinterpret_cast<Float*>(buf_tau.ptr),
reinterpret_cast<Float*>(buf_taussa.ptr),
reinterpret_cast<Float*>(buf_taussag.ptr)
);
});

m.def("rrtmgp_compute_cld_from_pade",
[](
int ncol,
int nlay,
int nbnd,
int nsizes,
py::array_t<Bool> mask,
py::array_t<Float> lwp,
py::array_t<Float> re,
py::array_t<Float> re_bounds_ext,
py::array_t<Float> re_bounds_ssa,
py::array_t<Float> re_bounds_asy,
int m_ext,
int n_ext,
py::array_t<Float> coeffs_ext,
int m_ssa,
int n_ssa,
py::array_t<Float> coeffs_ssa,
int m_asy,
int n_asy,
py::array_t<Float> coeffs_asy,
py::array_t<Float> tau,
py::array_t<Float> taussa,
py::array_t<Float> taussag
) {
if (ncol <= 0 || nlay <= 0 || nbnd <= 0 || nsizes <= 0) {
throw std::runtime_error("ncol, nlay, nbnd and nsteps must be positive integers");
}

if (m_ext <= 0 || n_ext <= 0) {
throw std::runtime_error("m_ext and n_ext must be positive integers");
}

if (m_ssa <= 0 || n_ssa <= 0) {
throw std::runtime_error("m_ssa and n_ssa must be positive integers");
}

if (m_asy <= 0 || n_asy <= 0) {
throw std::runtime_error("m_asy and n_asy must be positive integers");
}

if (mask.size() != ncol * nlay) throw std::runtime_error("Invalid size for input array 'mask'");
if (lwp.size() != ncol * nlay) throw std::runtime_error("Invalid size for input array 'lwp'");
if (re.size() != ncol * nlay) throw std::runtime_error("Invalid size for input array 're'");
if (re_bounds_ext.size() != nsizes + 1) throw std::runtime_error("Invalid size for input array 're_bounds_ext'");
if (re_bounds_ssa.size() != nsizes + 1) throw std::runtime_error("Invalid size for input array 're_bounds_ssa'");
if (re_bounds_asy.size() != nsizes + 1) throw std::runtime_error("Invalid size for input array 're_bounds_asy'");
if (coeffs_ext.size() != nbnd * nsizes * (m_ext + n_ext)) throw std::runtime_error("Invalid size for input array 'coeffs_ext'");
if (coeffs_ssa.size() != nbnd * nsizes * (m_ssa + n_ssa)) throw std::runtime_error("Invalid size for input array 'coeffs_ssa'");
if (coeffs_asy.size() != nbnd * nsizes * (m_asy + n_asy)) throw std::runtime_error("Invalid size for input array 'coeffs_asy'");
if (tau.size() != ncol * nlay * nbnd) throw std::runtime_error("Invalid size for input array 'tau'");
if (taussa.size() != ncol * nlay * nbnd) throw std::runtime_error("Invalid size for input array 'taussa'");
if (taussag.size() != ncol * nlay * nbnd) throw std::runtime_error("Invalid size for input array 'taussag'");

py::buffer_info buf_mask = mask.request();
py::buffer_info buf_lwp = lwp.request();
py::buffer_info buf_re = re.request();
py::buffer_info buf_re_bounds_ext = re_bounds_ext.request();
py::buffer_info buf_re_bounds_ssa = re_bounds_ssa.request();
py::buffer_info buf_re_bounds_asy = re_bounds_asy.request();
py::buffer_info buf_coeffs_ext = coeffs_ext.request();
py::buffer_info buf_coeffs_ssa = coeffs_ssa.request();
py::buffer_info buf_coeffs_asy = coeffs_asy.request();
py::buffer_info buf_tau = tau.request();
py::buffer_info buf_taussa = taussa.request();
py::buffer_info buf_taussag = taussag.request();

fortran::rrtmgp_compute_cld_from_pade(
ncol,
nlay,
nbnd,
nsizes,
reinterpret_cast<Bool*>(buf_mask.ptr),
reinterpret_cast<Float*>(buf_lwp.ptr),
reinterpret_cast<Float*>(buf_re.ptr),
reinterpret_cast<Float*>(buf_re_bounds_ext.ptr),
reinterpret_cast<Float*>(buf_re_bounds_ssa.ptr),
reinterpret_cast<Float*>(buf_re_bounds_asy.ptr),
m_ext,
n_ext,
reinterpret_cast<Float*>(buf_coeffs_ext.ptr),
m_ssa,
n_ssa,
reinterpret_cast<Float*>(buf_coeffs_ssa.ptr),
m_asy,
n_asy,
reinterpret_cast<Float*>(buf_coeffs_asy.ptr),
reinterpret_cast<Float*>(buf_tau.ptr),
reinterpret_cast<Float*>(buf_taussa.ptr),
reinterpret_cast<Float*>(buf_taussag.ptr)
);
});
}

0 comments on commit 937783b

Please sign in to comment.