Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
[Py] Added elu operator to Python API. (#3236)
Browse files Browse the repository at this point in the history
* Added elu operator to Python API.

* Added missing file.

* Specified elu function description.

* Expand docstring

* [Py] Added test with scalar for elu operator.

* Bugfix

*  [Py] Changed input type in elu test.

* Update test_ops_binary.py

* [Py] Syntax bugfix.

* [Py] Added elu operator to list in documentation.
  • Loading branch information
Ewa Tusień authored and diyessi committed Jul 22, 2019
1 parent a58d3bc commit e495561
Show file tree
Hide file tree
Showing 12 changed files with 179 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/sphinx/source/python_api/_autosummary/ngraph.ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ ngraph.ops
cosh
divide
dot
elu
equal
exp
floor
Expand Down
1 change: 1 addition & 0 deletions python/ngraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ngraph.ops import cosh
from ngraph.ops import divide
from ngraph.ops import dot
from ngraph.ops import elu
from ngraph.ops import equal
from ngraph.ops import exp
from ngraph.ops import floor
Expand Down
1 change: 1 addition & 0 deletions python/ngraph/impl/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from _pyngraph.op import Cosh
from _pyngraph.op import Divide
from _pyngraph.op import Dot
from _pyngraph.op import Elu
from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import Floor
Expand Down
22 changes: 20 additions & 2 deletions python/ngraph/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \
Expand All @@ -35,7 +35,7 @@
from ngraph.utils.input_validation import assert_list_of_ints
from ngraph.utils.reduction import get_reduction_axes
from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \
NodeInput, ScalarData
NodeInput, ScalarData, as_node
from ngraph.utils.types import get_element_type


Expand All @@ -60,6 +60,24 @@ def constant(value, dtype=None, name=None): # type: (NumericData, NumericType,
return make_constant_node(value, dtype)


@nameable_op
def elu(data, alpha, name=None): # type: (NodeInput, NodeInput, str) -> Node
"""Perform Exponential Linear Unit operation element-wise on data from input node.
Computes exponential linear: alpha * (exp(data) - 1) if < 0, data otherwise.
For more information refer to:
`Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
<http://arxiv.org/abs/1511.07289>`_
:param data: Input tensor. One of: input node, array or scalar.
:param alpha: Multiplier for negative values. One of: input node or scalar value.
:param name: Optional output node name.
:return: The new node performing an ELU operation on its input data element-wise.
"""
return Elu(as_node(data), as_node(alpha))


# Unary ops
@unary_op
def absolute(node, name=None): # type: (NodeInput, str) -> Node
Expand Down
30 changes: 30 additions & 0 deletions python/pyngraph/ops/elu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ngraph/op/fused/elu.hpp"
#include "pyngraph/ops/elu.hpp"

namespace py = pybind11;

void regclass_pyngraph_op_Elu(py::module m)
{
py::class_<ngraph::op::Elu, std::shared_ptr<ngraph::op::Elu>, ngraph::op::Op> elu(m, "Elu");
elu.doc() = "ngraph.impl.op.Elu wraps ngraph::op::Elu";
elu.def(py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
23 changes: 23 additions & 0 deletions python/pyngraph/ops/elu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

void regclass_pyngraph_op_Elu(py::module m);
30 changes: 30 additions & 0 deletions python/pyngraph/ops/fused/elu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ngraph/op/fused/elu.hpp"
#include "pyngraph/ops/fused/elu.hpp"

namespace py = pybind11;

void regclass_pyngraph_op_Elu(py::module m)
{
py::class_<ngraph::op::Elu, std::shared_ptr<ngraph::op::Elu>, ngraph::op::Op> elu(m, "Elu");
elu.doc() = "ngraph.impl.op.Elu wraps ngraph::op::Elu";
elu.def(py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
1 change: 1 addition & 0 deletions python/pyngraph/ops/regmodule_pyngraph_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Cosh(m_op);
regclass_pyngraph_op_Divide(m_op);
regclass_pyngraph_op_Dot(m_op);
regclass_pyngraph_op_Elu(m_op);
regclass_pyngraph_op_Equal(m_op);
regclass_pyngraph_op_Exp(m_op);
regclass_pyngraph_op_Floor(m_op);
Expand Down
1 change: 1 addition & 0 deletions python/pyngraph/ops/regmodule_pyngraph_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "pyngraph/ops/cosh.hpp"
#include "pyngraph/ops/divide.hpp"
#include "pyngraph/ops/dot.hpp"
#include "pyngraph/ops/elu.hpp"
#include "pyngraph/ops/equal.hpp"
#include "pyngraph/ops/exp.hpp"
#include "pyngraph/ops/floor.hpp"
Expand Down
1 change: 1 addition & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def cpp_flag(compiler):
'pyngraph/ops/ceiling.cpp',
'pyngraph/ops/divide.cpp',
'pyngraph/ops/dot.cpp',
'pyngraph/ops/elu.cpp',
'pyngraph/ops/equal.cpp',
'pyngraph/ops/exp.cpp',
'pyngraph/ops/floor.cpp',
Expand Down
69 changes: 69 additions & 0 deletions python/test/ngraph/test_ops_fused.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np

import ngraph as ng
from test.ngraph.util import get_runtime


def test_elu_operator():
runtime = get_runtime()

data_shape = [2, 2]
alpha_shape = [2]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
parameter_alpha = ng.parameter(alpha_shape, name='Alpha', dtype=np.float32)

model = ng.elu(parameter_data, parameter_alpha)
computation = runtime.computation(model, parameter_data, parameter_alpha)

value_data = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
value_alpha = np.array([3, 3], dtype=np.float32)

result = computation(value_data, value_alpha)
expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
assert np.allclose(result, expected)


def test_elu_operator_with_scalar_and_array():
runtime = get_runtime()

data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
alpha_value = np.float32(3)

model = ng.elu(data_value, alpha_value)
computation = runtime.computation(model)

result = computation()
expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
assert np.allclose(result, expected)


def test_elu_operator_with_scalar():
runtime = get_runtime()

data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
alpha_value = np.float32(3)

data_shape = [2, 2]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)

model = ng.elu(parameter_data, alpha_value)
computation = runtime.computation(model, parameter_data)

result = computation(data_value)
expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
assert np.allclose(result, expected)
2 changes: 1 addition & 1 deletion src/ngraph/op/fused/elu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ NodeVector op::Elu::decompose_op() const
auto data = get_argument(0);
auto alpha_node = get_argument(1);

alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data->get_shape());

shared_ptr<ngraph::Node> zero_node =
builder::make_constant(data->get_element_type(), data->get_shape(), 0);
Expand Down

0 comments on commit e495561

Please sign in to comment.