Skip to content

Commit

Permalink
Return balancing weights as numpy arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
robomics committed Nov 9, 2024
1 parent b438fca commit d7ebb77
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions src/file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,24 @@ static std::vector<std::string> avail_normalizations(const hictk::File &f) {
return norms;
}

static std::vector<double> weights(const hictk::File &f, std::string_view normalization,
bool divisive) {
static auto weights(const hictk::File &f, std::string_view normalization, bool divisive) {
using WeightVector = nb::ndarray<nb::numpy, nb::shape<-1>, nb::c_contig, double>;

if (normalization == "NONE") {
return WeightVector{};
}

const auto type = divisive ? hictk::balancing::Weights::Type::DIVISIVE
: hictk::balancing::Weights::Type::MULTIPLICATIVE;
return f.normalization(normalization).to_vector(type);

// NOLINTNEXTLINE
auto *weights_ptr = new std::vector<double>(f.normalization(normalization).to_vector(type));

auto capsule = nb::capsule(weights_ptr, [](void *vect_ptr) noexcept {
delete reinterpret_cast<std::vector<double> *>(vect_ptr); // NOLINT
});

return WeightVector{weights_ptr->data(), {weights_ptr->size()}, capsule};
}

static std::filesystem::path get_path(const hictk::File &f) { return f.path(); }
Expand Down Expand Up @@ -237,7 +250,8 @@ void declare_file_class(nb::module_ &m) {
file.def("has_normalization", &hictk::File::has_normalization, nb::arg("normalization"),
"Check whether a given normalization is available.");
file.def("weights", &file::weights, nb::arg("name"), nb::arg("divisive") = true,
"Fetch the balancing weights for the given normalization method.");
"Fetch the balancing weights for the given normalization method.",
nb::rv_policy::take_ownership);
}

} // namespace hictkpy::file

0 comments on commit d7ebb77

Please sign in to comment.