Skip to content

Commit

Permalink
Adaptations for recent dtens API changes in heyoka.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Nov 12, 2023
1 parent 31a58b0 commit e130fd0
Showing 1 changed file with 58 additions and 3 deletions.
61 changes: 58 additions & 3 deletions heyoka/expose_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <heyoka/config.hpp>

#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
Expand All @@ -20,7 +21,9 @@
#include <variant>
#include <vector>

#include <boost/iterator/transform_iterator.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <boost/safe_numerics/safe_integer.hpp>

#include <fmt/core.h>
#include <fmt/ranges.h>
Expand Down Expand Up @@ -60,6 +63,49 @@ namespace
template <typename T>
using uncvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

// Functor to transform on-the-fly the content of a dtens
// from sparse format into dense format.
struct dtens_t_it {
const heyoka::dtens *dt = nullptr;

std::pair<heyoka::dtens::v_idx_t, heyoka::expression>
operator()(const std::pair<heyoka::dtens::sv_idx_t, heyoka::expression> &p) const
{
const auto &[sv_idx, ex] = p;

// Init the dense vector from the component index.
heyoka::dtens::v_idx_t ret{sv_idx.first};

// Transform the sparse index/order pairs into dense format.
// NOTE: no overflow check needed on ++idx because dtens ensures that
// the number of variables can be represented by std::uint32_t.
std::uint32_t idx = 0;
for (auto it = sv_idx.second.begin(); it != sv_idx.second.end(); ++idx) {
if (it->first == idx) {
// The current index shows up in the sparse vector,
// fetch the corresponding order and move to the next
// element of the sparse vector.
ret.push_back(it->second);
assert(it->second != 0u);
++it;
} else {
// The current index does not show up in the sparse
// vector, set the order to zero.
ret.push_back(0);
}
}

// Sanity check on the number of diff variables
// inferred from the sparse vector.
assert(ret.size() - 1u <= dt->get_nvars());

// Pad missing values at the end of ret.
ret.resize(boost::safe_numerics::safe<decltype(ret.size())>(dt->get_nvars()) + 1);

return std::make_pair(std::move(ret), ex);
}
};

} // namespace

} // namespace detail
Expand Down Expand Up @@ -492,13 +538,19 @@ void expose_expression(py::module_ &m)

const auto s_idx = boost::numeric_cast<std::iterator_traits<hey::dtens::iterator>::difference_type>(idx);

return dt.begin()[s_idx];
return detail::dtens_t_it{&dt}(dt.begin()[s_idx]);
});
dtens_cl.def("__contains__",
[](const hey::dtens &dt, const hey::dtens::v_idx_t &v_idx) { return dt.find(v_idx) != dt.end(); });
// Iterator.
dtens_cl.def(
"__iter__", [](const hey::dtens &dt) { return py::make_key_iterator(dt.begin(), dt.end()); },
"__iter__",
[](const hey::dtens &dt) {
auto t_begin = boost::iterators::make_transform_iterator(dt.begin(), detail::dtens_t_it{&dt});
auto t_end = boost::iterators::make_transform_iterator(dt.end(), detail::dtens_t_it{&dt});

return py::make_key_iterator(t_begin, t_end);
},
// NOTE: the calling dtens (argument index 1) needs to be kept alive at least until
// the return value (argument index 0) is freed by the garbage collector.
// This ensures that if we fetch an iterator and then delete the originating dtens object,
Expand All @@ -514,7 +566,10 @@ void expose_expression(py::module_ &m)
[](const hey::dtens &dt, std::uint32_t order, std::optional<std::uint32_t> component) {
const auto sr = component ? dt.get_derivatives(*component, order) : dt.get_derivatives(order);

return std::vector(sr.begin(), sr.end());
auto t_begin = boost::iterators::make_transform_iterator(sr.begin(), detail::dtens_t_it{&dt});
auto t_end = boost::iterators::make_transform_iterator(sr.end(), detail::dtens_t_it{&dt});

return std::vector(t_begin, t_end);
},
"diff_order"_a, "component"_a = py::none{});
// Gradient.
Expand Down

0 comments on commit e130fd0

Please sign in to comment.