Skip to content

Commit

Permalink
Merge pull request #163 from bluescarni/pr/heyoka_updates
Browse files Browse the repository at this point in the history
Deal with API changes in heyoka
  • Loading branch information
bluescarni authored Jan 17, 2024
2 parents 30516ee + a86c7a1 commit 2816f23
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 39 deletions.
19 changes: 7 additions & 12 deletions heyoka/expose_batch_integrators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ void expose_batch_integrator_impl(py::module_ &m, const std::string &suffix)
using t_ev_t = hey::t_event_batch<T>;
using nt_ev_t = hey::nt_event_batch<T>;

using sys_t = std::vector<std::pair<hey::expression, hey::expression>>;

// Implementation of the ctor.
auto tab_ctor_impl = [](const auto &sys, const py::iterable &state_ob, std::optional<py::iterable> time_ob,
auto tab_ctor_impl = [](const sys_t &sys, const py::iterable &state_ob, std::optional<py::iterable> time_ob,
std::optional<py::iterable> pars_ob, T tol, bool high_accuracy, bool compact_mode,
std::vector<t_ev_t> tes, std::vector<nt_ev_t> ntes, bool parallel_mode, unsigned opt_level,
bool force_avx512, bool slp_vectorize, bool fast_math) {
Expand Down Expand Up @@ -175,21 +177,14 @@ void expose_batch_integrator_impl(py::module_ &m, const std::string &suffix)
py::class_<hey::taylor_adaptive_batch<T>> tab_c(m, fmt::format("taylor_adaptive_batch_{}", suffix).c_str(),
py::dynamic_attr{});

using variant_t
= std::variant<std::vector<std::pair<hey::expression, hey::expression>>, std::vector<hey::expression>>;

tab_c
.def(py::init([tab_ctor_impl](const variant_t &sys, const py::iterable &state, std::optional<py::iterable> time,
.def(py::init([tab_ctor_impl](const sys_t &sys, const py::iterable &state, std::optional<py::iterable> time,
std::optional<py::iterable> pars, T tol, bool high_accuracy, bool compact_mode,
std::vector<t_ev_t> tes, std::vector<nt_ev_t> ntes, bool parallel_mode,
unsigned opt_level, bool force_avx512, bool slp_vectorize, bool fast_math) {
return std::visit(
[&](const auto &value) {
return tab_ctor_impl(value, state, std::move(time), std::move(pars), tol, high_accuracy,
compact_mode, std::move(tes), std::move(ntes), parallel_mode, opt_level,
force_avx512, slp_vectorize, fast_math);
},
sys);
return tab_ctor_impl(sys, state, std::move(time), std::move(pars), tol, high_accuracy, compact_mode,
std::move(tes), std::move(ntes), parallel_mode, opt_level, force_avx512,
slp_vectorize, fast_math);
}),
"sys"_a, "state"_a, "time"_a = py::none{}, "pars"_a = py::none{}, "tol"_a.noconvert() = static_cast<T>(0),
"high_accuracy"_a = false, "compact_mode"_a = false, "t_events"_a = py::list{}, "nt_events"_a = py::list{},
Expand Down
49 changes: 22 additions & 27 deletions heyoka/taylor_expose_integrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,38 +87,33 @@ void expose_taylor_integrator_impl(py::module &m, const std::string &suffix)
using t_ev_t = hey::t_event<T>;
using nt_ev_t = hey::nt_event<T>;

// Union of ODE system types, used in the ctor.
using sys_t = std::variant<std::vector<std::pair<hey::expression, hey::expression>>, std::vector<hey::expression>>;
using sys_t = std::vector<std::pair<hey::expression, hey::expression>>;

py::class_<hey::taylor_adaptive<T>> cl(m, (fmt::format("taylor_adaptive_{}", suffix)).c_str(), py::dynamic_attr{});
cl.def(py::init([](const sys_t &sys, std::vector<T> state, T time, std::vector<T> pars, T tol, bool high_accuracy,
bool compact_mode, std::vector<t_ev_t> tes, std::vector<nt_ev_t> ntes, bool parallel_mode,
unsigned opt_level, bool force_avx512, bool slp_vectorize, bool fast_math, long long prec) {
return std::visit(
[&](const auto &val) {
// NOTE: GIL release is fine here even if the events contain
// Python objects, as the event vectors are moved in
// upon construction and thus we should never end up calling
// into the interpreter.
py::gil_scoped_release release;

return hey::taylor_adaptive<T>{val,
std::move(state),
kw::time = time,
kw::tol = tol,
kw::high_accuracy = high_accuracy,
kw::compact_mode = compact_mode,
kw::pars = std::move(pars),
kw::t_events = std::move(tes),
kw::nt_events = std::move(ntes),
kw::parallel_mode = parallel_mode,
kw::opt_level = opt_level,
kw::force_avx512 = force_avx512,
kw::slp_vectorize = slp_vectorize,
kw::fast_math = fast_math,
kw::prec = prec};
},
sys);
// NOTE: GIL release is fine here even if the events contain
// Python objects, as the event vectors are moved in
// upon construction and thus we should never end up calling
// into the interpreter.
py::gil_scoped_release release;

return hey::taylor_adaptive<T>{sys,
std::move(state),
kw::time = time,
kw::tol = tol,
kw::high_accuracy = high_accuracy,
kw::compact_mode = compact_mode,
kw::pars = std::move(pars),
kw::t_events = std::move(tes),
kw::nt_events = std::move(ntes),
kw::parallel_mode = parallel_mode,
kw::opt_level = opt_level,
kw::force_avx512 = force_avx512,
kw::slp_vectorize = slp_vectorize,
kw::fast_math = fast_math,
kw::prec = prec};
}),
"sys"_a, "state"_a.noconvert(), "time"_a.noconvert() = static_cast<T>(0), "pars"_a.noconvert() = py::list{},
"tol"_a.noconvert() = static_cast<T>(0), "high_accuracy"_a = false, "compact_mode"_a = default_cm<T>,
Expand Down

0 comments on commit 2816f23

Please sign in to comment.