Skip to content

Commit

Permalink
Merge pull request #141 from bluescarni/pr/relu
Browse files Browse the repository at this point in the history
Expose relu and relup
  • Loading branch information
bluescarni authored Nov 5, 2023
2 parents 47b2cf4 + 33f3119 commit 1854ffb
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 17 deletions.
6 changes: 6 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ Changelog
New
~~~

- Add a model for feed-forward neural networks
(`#142 <https://github.com/bluescarni/heyoka.py/pull/142>`__).
- Implement (leaky) ``ReLU`` and its derivative in the expression
system (`#141 <https://github.com/bluescarni/heyoka.py/pull/141>`__).
- Implement the eccentric longitude :math:`F` in the expression
system (`#140 <https://github.com/bluescarni/heyoka.py/pull/140>`__).
- Implement the delta eccentric anomaly :math:`\Delta E` in the expression
Expand All @@ -19,6 +23,8 @@ New
(`#140 <https://github.com/bluescarni/heyoka.py/pull/140>`__).
- New example notebook implementing Lagrange propagation
(`#140 <https://github.com/bluescarni/heyoka.py/pull/140>`__).
- New example notebook on the continuation of periodic orbits
in the CR3BP (`#97 <https://github.com/bluescarni/heyoka.py/pull/97>`__).

Changes
~~~~~~~
Expand Down
15 changes: 15 additions & 0 deletions heyoka/_test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,18 @@ def test_fix_unfix(self):
self.assertEqual(fix_nn(expression(1.1)), expression(1.1))
self.assertEqual(unfix(fix(x + y)), x + y)
self.assertEqual(unfix(arg=[fix(x + fix(y)), fix(x - fix(y))]), [x + y, x - y])

def test_relu_wrappers(self):
from . import make_vars, leaky_relu, leaky_relup, relu, relup

x, y = make_vars("x", "y")

self.assertEqual(leaky_relu(0.)(x), relu(x))
self.assertEqual(leaky_relup(0.)(x), relup(x))
self.assertEqual(leaky_relu(0.1)(x+y), relu(x+y,0.1))
self.assertEqual(leaky_relup(0.1)(x+y), relup(x+y,0.1))

self.assertEqual(leaky_relu(0.)(x*y), relu(x*y))
self.assertEqual(leaky_relup(0.)(x*y), relup(x*y))
self.assertEqual(leaky_relu(0.1)(x*y+y), relu(x*y+y,0.1))
self.assertEqual(leaky_relup(0.1)(x*y+y), relup(x*y+y,0.1))
6 changes: 6 additions & 0 deletions heyoka/_test_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ def test_func_conversion(self):
to_sympy(core.kepDE(hx, hy, hz)), spy.Function("heyoka_kepDE")(x, y, z)
)

# relu/relup.
self.assertEqual(to_sympy(core.relu(hx)), spy.Piecewise((x, x > 0), (0., True)))
self.assertEqual(to_sympy(core.relup(hx)), spy.Piecewise((1., x > 0), (0., True)))
self.assertEqual(to_sympy(core.relu(hx, 0.1)), spy.Piecewise((x, x > 0), (x*0.1, True)))
self.assertEqual(to_sympy(core.relup(hx, 0.1)), spy.Piecewise((1., x > 0), (0.1, True)))

self.assertEqual(-1.0 * hx, from_sympy(-x))
self.assertEqual(to_sympy(-hx), -x)

Expand Down
45 changes: 28 additions & 17 deletions heyoka/expose_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,23 +278,34 @@ void expose_expression(py::module_ &m)
m.def("prod", &hey::prod, "terms"_a);

// NOTE: need explicit casts for sqrt and exp due to the presence of overloads for number.
m.def("sqrt", static_cast<hey::expression (*)(hey::expression)>(&hey::sqrt));
m.def("exp", static_cast<hey::expression (*)(hey::expression)>(&hey::exp));
m.def("log", &hey::log);
m.def("sin", &hey::sin);
m.def("cos", &hey::cos);
m.def("tan", &hey::tan);
m.def("asin", &hey::asin);
m.def("acos", &hey::acos);
m.def("atan", &hey::atan);
m.def("sinh", &hey::sinh);
m.def("cosh", &hey::cosh);
m.def("tanh", &hey::tanh);
m.def("asinh", &hey::asinh);
m.def("acosh", &hey::acosh);
m.def("atanh", &hey::atanh);
m.def("sigmoid", &hey::sigmoid);
m.def("erf", &hey::erf);
m.def("sqrt", static_cast<hey::expression (*)(hey::expression)>(&hey::sqrt), "arg"_a);
m.def("exp", static_cast<hey::expression (*)(hey::expression)>(&hey::exp), "arg"_a);
m.def("log", &hey::log, "arg"_a);
m.def("sin", &hey::sin, "arg"_a);
m.def("cos", &hey::cos, "arg"_a);
m.def("tan", &hey::tan, "arg"_a);
m.def("asin", &hey::asin, "arg"_a);
m.def("acos", &hey::acos, "arg"_a);
m.def("atan", &hey::atan, "arg"_a);
m.def("sinh", &hey::sinh, "arg"_a);
m.def("cosh", &hey::cosh, "arg"_a);
m.def("tanh", &hey::tanh, "arg"_a);
m.def("asinh", &hey::asinh, "arg"_a);
m.def("acosh", &hey::acosh, "arg"_a);
m.def("atanh", &hey::atanh, "arg"_a);
m.def("sigmoid", &hey::sigmoid, "arg"_a);
m.def("erf", &hey::erf, "arg"_a);
m.def("relu", &hey::relu, "arg"_a, "slope"_a = 0.);
m.def("relup", &hey::relup, "arg"_a, "slope"_a = 0.);

// Leaky relu wrappers.
py::class_<hey::leaky_relu> lr_class(m, "leaky_relu", py::dynamic_attr{});
lr_class.def(py::init([](double slope) { return hey::leaky_relu(slope); }), "slope"_a);
lr_class.def("__call__", &hey::leaky_relu::operator(), "arg"_a);

py::class_<hey::leaky_relup> lrp_class(m, "leaky_relup", py::dynamic_attr{});
lrp_class.def(py::init([](double slope) { return hey::leaky_relup(slope); }), "slope"_a);
lrp_class.def("__call__", &hey::leaky_relup::operator(), "arg"_a);

// NOTE: when exposing multivariate functions, we want to be able to pass
// in numerical arguments for convenience. Thus, we expose such functions taking
Expand Down
49 changes: 49 additions & 0 deletions heyoka/setup_sympy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,55 @@ void setup_sympy(py::module &m)
auto sympy_kepDE = py::object(detail::spy->attr("Function")("heyoka_kepDE"));
detail::fmap[typeid(hy::detail::kepDE_impl)] = sympy_kepDE;

// relu, relup and leaky variants.
// NOTE: these are implemented as piecewise functions:
// https://medium.com/@mathcube7/piecewise-functions-in-pythons-sympy-83f857948d3
detail::fmap[typeid(hy::detail::relu_impl)]
= [](std::unordered_map<const void *, py::object> &func_map, const hy::func &f) -> py::object {
assert(f.args().size() == 1u);

// Convert the argument to SymPy.
auto s_arg = detail::to_sympy_impl(func_map, f.args()[0]);

// Fetch the slope value.
const auto slope = f.extract<hy::detail::relu_impl>()->get_slope();

// Create the condition arg > 0.
auto cond = s_arg.attr("__gt__")(0);

// Fetch the piecewise function.
auto pw = detail::spy->attr("Piecewise");

if (slope == 0) {
return pw(py::make_tuple(s_arg, cond), py::make_tuple(0., true));
} else {
return pw(py::make_tuple(s_arg, cond), py::make_tuple(py::cast(slope) * s_arg, true));
}
};

detail::fmap[typeid(hy::detail::relup_impl)]
= [](std::unordered_map<const void *, py::object> &func_map, const hy::func &f) -> py::object {
assert(f.args().size() == 1u);

// Convert the argument to SymPy.
auto s_arg = detail::to_sympy_impl(func_map, f.args()[0]);

// Fetch the slope value.
const auto slope = f.extract<hy::detail::relup_impl>()->get_slope();

// Create the condition arg > 0.
auto cond = s_arg.attr("__gt__")(0);

// Fetch the piecewise function.
auto pw = detail::spy->attr("Piecewise");

if (slope == 0) {
return pw(py::make_tuple(1., cond), py::make_tuple(0., true));
} else {
return pw(py::make_tuple(1., cond), py::make_tuple(slope, true));
}
};

// sigmoid.
detail::fmap[typeid(hy::detail::sigmoid_impl)]
= [](std::unordered_map<const void *, py::object> &func_map, const hy::func &f) {
Expand Down

0 comments on commit 1854ffb

Please sign in to comment.