Skip to content

Commit

Permalink
fix copy and sort bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
Yikai-Liao committed Jan 12, 2024
1 parent 35d4c73 commit bc1474b
Showing 1 changed file with 22 additions and 23 deletions.
45 changes: 22 additions & 23 deletions py_src/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,22 @@ NB_MAKE_OPAQUE(vec<Quarter::unit>)
// PYBIND11_MAKE_OPAQUE(vec<Second::unit>)

template<typename T>
void sort_by_py_key(vec<T> &self, const py::object & key) {
void sort_by_py_key(vec<T> &self, const py::callable & key) {
pdqsort(self.begin(), self.end(), [&key](const T &a, const T &b) {
return key(a) < key(b);
});
}

template<typename T>
vec<T> & py_sort_inplace(vec<T> &self, const py::callable & key, const bool reverse) {
vec<T> & py_sort_inplace(vec<T> &self, const py::object & key, const bool reverse) {
if (key.is_none()) ops::sort_by_time(self);
else sort_by_py_key(self, key);
else sort_by_py_key(self, py::cast<py::callable>(key));
if (reverse) std::reverse(self.begin(), self.end());
return self;
}

template<typename T>
py::object py_sort(vec<T> &self, const py::callable & key, const bool reverse, const bool inplace = false) {
py::object py_sort(vec<T> &self, const py::object & key, const bool reverse, const bool inplace = false) {
if (inplace) {
py_sort_inplace(self, key, reverse);
return py::cast(self, py::rv_policy::reference);
Expand Down Expand Up @@ -93,6 +93,10 @@ py::object py_filter (vec<T> & self, py::callable & func, const bool inplace) {
} return py::cast(ans, py::rv_policy::move);
}

template<typename T>
vec<T> copy_vec(const vec<T> &self) {
return vec<T>(self);
}

#define NDARR(DTYPE, DIM) py::ndarray<DTYPE, py::ndim<DIM>, py::c_contig, py::device::cpu>

Expand All @@ -114,7 +118,6 @@ std::tuple<py::class_<T>, py::class_<vec<T>>> time_stamp_base(py::module_ &m, co
.def("__repr__", &T::to_string)
.def(py::self == py::self) // NOLINT
.def(py::self != py::self) // NOLINT
// .def(py::pickle( &py_to_bytes<T>, &py_from_bytes<T>));
.def("__getstate__", &py_to_bytes<T>)
.def("__setstate__", &py_from_bytes<T>);

Expand All @@ -124,7 +127,9 @@ std::tuple<py::class_<T>, py::class_<vec<T>>> time_stamp_base(py::module_ &m, co
.def("__repr__", [](const vec<T> &self) {
return fmt::format("{}", fmt::join(self, ",\n"));
})
// .def(py::pickle( &py_to_bytes<vec<T>>, &py_from_bytes<vec<T>>));
.def("__copy__" , &copy_vec<T>)
.def("__deepcopy__" , &copy_vec<T>)
.def("copy", &copy_vec<T>)
.def("__getstate__", &py_to_bytes<vec<T>>)
.def("__setstate__", &py_from_bytes<vec<T>>)
.def("filter", &py_filter<T>, py::arg("func"), py::arg("inplace")=false);
Expand Down Expand Up @@ -434,7 +439,7 @@ py::class_<symusic::TextMeta<T>> bind_text_meta_class(py::module_ &m, const std:
}

template<typename T>
Track<T>&py_sort_track_inplace(Track<T> &self, py::callable &key, const bool reverse) {
Track<T>&py_sort_track_inplace(Track<T> &self, py::object &key, const bool reverse) {
py_sort_inplace(self.notes, key, reverse);
py_sort_inplace(self.controls, key, reverse);
py_sort_inplace(self.pitch_bends, key, reverse);
Expand All @@ -443,7 +448,7 @@ Track<T>&py_sort_track_inplace(Track<T> &self, py::callable &key, const bool rev
};

template<typename T>
py::object py_sort_track(Track<T> &self, py::callable &key, const bool reverse, const bool inplace = false) {
py::object py_sort_track(Track<T> &self, py::object &key, const bool reverse, const bool inplace = false) {
if (inplace) {
py_sort_track_inplace(self, key, reverse);
return py::cast(self, py::rv_policy::reference);
Expand Down Expand Up @@ -503,12 +508,11 @@ py::class_<Track<T>> bind_track_class(py::module_ &m, const std::string & name_)
.def_rw("program", &Track<T>::program)
.def_rw("name", &Track<T>::name)
.def_prop_ro("ttype", [](const Track<T> &) { return T(); })
// .def(py::pickle( &py_to_bytes<Track<T>>, &py_from_bytes<Track<T>>))
.def("__getstate__", &py_to_bytes<Track<T>>)
.def("__setstate__", &py_from_bytes<Track<T>>)
.def(py::self == py::self) // NOLINT
.def(py::self != py::self) // NOLINT
.def("sort", &py_sort_track<T>, py::arg("key")=py::none(), py::arg("reverse")=false, py::arg("inplace")=false)
.def("sort", &py_sort_track<T>, py::arg("key")=py::none(), py::arg("reverse")=false, py::arg("inplace")=true)
.def("end", &Track<T>::end)
.def("start", &Track<T>::start)
.def("note_num", &Track<T>::note_num)
Expand All @@ -535,12 +539,12 @@ py::class_<Track<T>> bind_track_class(py::module_ &m, const std::string & name_)
.def("sort", [](vec<Track<T>> &self, const py::object & key, const bool reverse, const bool inplace) {
if (key.is_none()) throw std::invalid_argument("key must be specified");
if (inplace) {
sort_by_py_key(self, key);
sort_by_py_key(self, py::cast<py::callable>(key));
if (reverse) std::reverse(self.begin(), self.end());
return py::cast(self, py::rv_policy::reference);
} else { // copy
auto copy = vec<Track<T>>(self);
sort_by_py_key(copy, key);
sort_by_py_key(copy, py::cast<py::callable>(key));
if (reverse) std::reverse(copy.begin(), copy.end());
return py::cast(copy, py::rv_policy::move);
}
Expand All @@ -552,9 +556,11 @@ py::class_<Track<T>> bind_track_class(py::module_ &m, const std::string & name_)
}
return fmt::format("[{}]", fmt::join(strs, ", "));
})
// .def(py::pickle( &py_to_bytes<vec<Track<T>>>, &py_from_bytes<vec<Track<T>>>))
.def("__getstate__", &py_to_bytes<vec<Track<T>>>)
.def("__setstate__", &py_from_bytes<vec<Track<T>>>)
.def("copy", &copy_vec<Track<T>>)
.def("__copy__" , &copy_vec<Track<T>>)
.def("__deepcopy__" , &copy_vec<Track<T>>)
.def_prop_ro("ttype", [](const vec<Track<T>> &){ return T(); })
.def("filter", &py_filter<Track<T>>, py::arg("func"), py::arg("inplace")=false);

Expand All @@ -565,7 +571,7 @@ py::class_<Track<T>> bind_track_class(py::module_ &m, const std::string & name_)

// py sort score
template<typename T>
Score<T>& py_sort_score_inplace(Score<T> &self, py::callable& key, bool reverse) {
Score<T>& py_sort_score_inplace(Score<T> &self, py::object& key, bool reverse) {
py_sort_inplace(self.time_signatures, key, reverse);
py_sort_inplace(self.key_signatures, key, reverse);
py_sort_inplace(self.tempos, key, reverse);
Expand All @@ -577,7 +583,7 @@ Score<T>& py_sort_score_inplace(Score<T> &self, py::callable& key, bool reverse)
}

template<typename T>
py::object py_sort_score(Score<T> &self, py::callable& key, bool reverse, bool inplace = false) {
py::object py_sort_score(Score<T> &self, py::object& key, bool reverse, bool inplace = false) {
if (inplace) {
py_sort_score_inplace(self, key, reverse);
return py::cast(self, py::rv_policy::reference);
Expand Down Expand Up @@ -639,13 +645,10 @@ py::class_<Score<T>> bind_score_class(py::module_ &m, const std::string & name_)

return py::class_<Score<T>>(m, name.c_str())
.def(py::init<const i32>(), py::arg("tpq"))
// .def(py::init([](const Score<T> &other) { return other.copy(); }), "Copy constructor", py::arg("other"))
.def("__init__", [](Score<T> *self, const Score<T> &other) { new (self) Score<T>(other); }, "Copy constructor", py::arg("other"))
// .def(py::init(&midi2score<T,std::string>), "Load from midi file", py::arg("path"))
.def("__init__", [](Score<T> *self, const std::string &path) {
new (self) Score<T>(std::move(midi2score<T, std::string>(path)));
}, "Load from midi file", py::arg("path"))
// .def(py::init(&midi2score<T, std::filesystem::path>), "Load from midi file", py::arg("path"))
.def("__init__", [](Score<T> *self, const std::filesystem::path &path) {
new (self) Score<T>(std::move(midi2score<T, std::filesystem::path>(path)));
}, "Load from midi file", py::arg("path"))
Expand All @@ -672,7 +675,7 @@ py::class_<Score<T>> bind_score_class(py::module_ &m, const std::string & name_)
.def("__setstate__", &py_from_bytes<Score<T>>)
.def(py::self == py::self) // NOLINT
.def(py::self != py::self) // NOLINT
.def("sort", &py_sort_score<T>, py::arg("key")=py::none(), py::arg("reverse")=false, py::arg("inplace")=false)
.def("sort", &py_sort_score<T>, py::arg("key")=py::none(), py::arg("reverse")=false, py::arg("inplace")=true)
.def("clip", &Score<T>::clip, "Clip events a given time range", py::arg("start"), py::arg("end"), py::arg("clip_end")=false)
.def("shift_time", &py_shift_time_score<T>, py::arg("offset"), py::arg("inplace")=false)
.def("shift_pitch", &py_shift_pitch_score<T>, py::arg("offset"), py::arg("inplace")=false)
Expand Down Expand Up @@ -880,10 +883,6 @@ NB_MODULE(core, m) {
.def("__repr__", [](const Second &) { return "symsuic.core.Second"; })
.def("is_time_unit", [](const Second &) { return true; });

// bind vec for Tick::unit and Quarter::unit
py::bind_vector<vec<i32>>(m, "i32List");
py::bind_vector<vec<f32>>(m, "f32List");

// def __eq__ for all time units
tick.def("__eq__", [](const Tick &, const py::object &other) {
if (py::isinstance<Tick>(other)) return true;
Expand Down

0 comments on commit bc1474b

Please sign in to comment.