PRelu operator (#458)
* add prelu operator

* clang format

* add prelu to gpu lowering

* add unit tests for the PRelu operator

* clang format

* add missing onnx file for PRelu operator

* update unit tests for prelu operator

* clang format

Co-authored-by: mvermeulen <[email protected]>
Co-authored-by: Paul Fultz II <[email protected]>
3 people authored Mar 7, 2020
1 parent a22189d commit 63d8e40
Showing 13 changed files with 165 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/include/migraphx/op/prelu.hpp
@@ -0,0 +1,22 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_PRELU_HPP
#define MIGRAPHX_GUARD_OPERATORS_PRELU_HPP

#include <migraphx/op/binary.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct prelu : binary<prelu>
{
    auto apply() const
    {
        return [](auto x, auto slope) { return ((x < 0) ? (x * slope) : x); };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
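
For reference, `apply()` returns the per-element PReLU rule — f(x, slope) = x when x >= 0 and slope * x otherwise — while the `binary<prelu>` base supplies the surrounding shape checking and elementwise plumbing. A minimal standalone sketch of the same rule over plain vectors (an illustration only, not part of this commit):

```cpp
#include <cstddef>
#include <vector>

// Standalone reference of the elementwise rule encoded by prelu::apply();
// assumes x and slope have the same length.
std::vector<float> prelu_reference(const std::vector<float>& x, const std::vector<float>& slope)
{
    std::vector<float> out(x.size());
    for(std::size_t i = 0; i < x.size(); i++)
        out[i] = x[i] < 0 ? x[i] * slope[i] : x[i];
    return out;
}
```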
1 change: 1 addition & 0 deletions src/include/migraphx/operators.hpp
@@ -52,6 +52,7 @@
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/prelu.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/pow.hpp>
1 change: 1 addition & 0 deletions src/onnx/onnx.cpp
@@ -68,6 +68,7 @@ struct onnx_parser
add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{});
add_binary_op("Pow", op::pow{});
add_binary_op("PRelu", op::prelu{});
add_binary_op("Sub", op::sub{});

add_variadic_op("Sum", op::add{});
1 change: 1 addition & 0 deletions src/targets/gpu/CMakeLists.txt
@@ -46,6 +46,7 @@ add_library(migraphx_device
device/mul_add_relu.cpp
device/pad.cpp
device/pow.cpp
device/prelu.cpp
device/reduce_max.cpp
device/reduce_mean.cpp
device/reduce_min.cpp
18 changes: 18 additions & 0 deletions src/targets/gpu/device/prelu.cpp
@@ -0,0 +1,18 @@
#include <migraphx/gpu/device/prelu.hpp>
#include <migraphx/gpu/device/nary.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void prelu(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
    nary(stream, result, arg1, arg2)(
        [](auto x, auto slope) __device__ { return ((x < 0) ? (x * slope) : x); });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
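
The `nary()` helper launches a pointwise kernel that applies the `__device__` lambda to corresponding elements of `arg1` and `arg2` and writes into `result`, handling shapes and types generically. For contiguous float tensors, the computation it performs is roughly equivalent to the hypothetical standalone kernel below (illustration only; the grid setup and signature are assumptions, not the actual nary implementation):

```cpp
#include <cstddef>
#include <hip/hip_runtime.h>

// Hypothetical HIP kernel showing the per-element work the nary() call above
// performs for packed float tensors of n elements.
__global__ void prelu_kernel(const float* x, const float* slope, float* out, std::size_t n)
{
    std::size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        out[i] = x[i] < 0 ? x[i] * slope[i] : x[i];
}
```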
21 changes: 21 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/device/prelu.hpp
@@ -0,0 +1,21 @@

#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PRELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PRELU_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void prelu(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
19 changes: 19 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/prelu.hpp
@@ -0,0 +1,19 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_PRELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_PRELU_HPP

#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/prelu.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct hip_prelu : binary_device<hip_prelu, device::prelu>
{
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
2 changes: 2 additions & 0 deletions src/targets/gpu/lowering.cpp
@@ -70,6 +70,7 @@
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/prelu.hpp>
#include <utility>
#include <functional>
#include <algorithm>
@@ -160,6 +161,7 @@ struct miopen_apply
add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff");
add_generic_op<hip_relu>("relu");
add_generic_op<hip_prelu>("prelu");
add_generic_op<hip_sign>("sign");
add_generic_op<hip_sigmoid>("sigmoid");
add_generic_op<hip_ceil>("ceil");
15 changes: 15 additions & 0 deletions test/cpu_ops_test.cpp
@@ -678,6 +678,21 @@ TEST_CASE(log_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}

TEST_CASE(prelu_test)
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {3}};
    auto x     = p.add_literal(migraphx::literal{s, {-1, 0, 2}});
    auto slope = p.add_literal(migraphx::literal{s, {2, 1, 2}});
    p.add_instruction(migraphx::op::prelu{}, x, slope);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({}).back();
    std::vector<float> results_vector;
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-2.0f, 0.0f, 2.0f};
    EXPECT(migraphx::verify_range(results_vector, gold));
}
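
Checking the gold values by hand: for x = {-1, 0, 2} and slope = {2, 1, 2}, the negative entry gives -1 * 2 = -2, the zero entry stays 0, and the positive entry 2 passes through unchanged, matching the expected {-2.0f, 0.0f, 2.0f}.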

TEST_CASE(pow_test)
{
migraphx::program p;
15 changes: 15 additions & 0 deletions test/gpu/ops_test.cpp
@@ -346,6 +346,21 @@ struct test_pow : verify_program<test_pow>
}
};

struct test_prelu_brcst : verify_program<test_prelu_brcst>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape s{migraphx::shape::float_type, {6}};
        auto x   = p.add_parameter("x", s);
        auto slp = p.add_parameter("slp", s);
        auto r   = p.add_instruction(migraphx::op::prelu{}, x, slp);
        p.add_return({r});

        return p;
    }
};
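
The verify_program tests in this file build the same program for the reference target and the GPU target and check that the results agree, so test_prelu_brcst exercises the hip_prelu lowering added above end to end.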

struct test_sin : verify_program<test_sin>
{
migraphx::program create_program() const
16 changes: 16 additions & 0 deletions test/onnx/gen_onnx.py
@@ -1386,6 +1386,22 @@ def pow_test():
return ([node], [arg0, arg1], [arg_out])


@onnx_test
def prelu_brcst_test():
    arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
    arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 5])
    arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT,
                                            [2, 3, 4, 5])

    node = onnx.helper.make_node(
        'PRelu',
        inputs=['0', '1'],
        outputs=['out'],
    )

    return ([node], [arg0, arg1], [arg_out])


@onnx_test
def reducel1_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
14 changes: 14 additions & 0 deletions test/onnx/onnx_test.cpp
@@ -1065,6 +1065,20 @@ TEST_CASE(pow_test)
EXPECT(p == prog);
}

TEST_CASE(prelu_brcst_test)
{
    migraphx::program p;
    auto l0  = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
    auto l1  = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
    auto bl1 = p.add_instruction(migraphx::op::multibroadcast{l0->get_shape().lens()}, l1);
    auto ret = p.add_instruction(migraphx::op::prelu{}, l0, bl1);
    p.add_return({ret});

    auto prog = migraphx::parse_onnx("prelu_brcst_test.onnx");

    EXPECT(p == prog);
}
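
Because the slope input has shape {4, 5} while the data input is {2, 3, 4, 5}, the parser inserts a multibroadcast so both operands of prelu share the larger shape, following ONNX's NumPy-style broadcasting rules. A hypothetical driver for the generated model could look like the sketch below; the include paths, the parameter-map form of eval, and the sample values are assumptions for illustration and not part of this commit:

```cpp
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape.hpp>
#include <vector>

int main()
{
    auto p = migraphx::parse_onnx("prelu_brcst_test.onnx");
    p.compile(migraphx::cpu::target{});

    migraphx::shape xs{migraphx::shape::float_type, {2, 3, 4, 5}};
    migraphx::shape ss{migraphx::shape::float_type, {4, 5}};
    std::vector<float> x(xs.elements(), -1.0f);     // all-negative data input "0"
    std::vector<float> slope(ss.elements(), 0.25f); // slope input "1", broadcast over {2, 3}

    auto result = p.eval({{"0", migraphx::argument{xs, x.data()}},
                          {"1", migraphx::argument{ss, slope.data()}}})
                      .back();
    // Every output element should be -0.25f: slope * x for x < 0.
}
```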

TEST_CASE(reducel1_test)
{
migraphx::program p;
20 changes: 20 additions & 0 deletions test/onnx/prelu_brcst_test.onnx
@@ -0,0 +1,20 @@
(Binary ONNX protobuf: the prelu_brcst_test graph generated by gen_onnx.py above, with inputs "0" and "1", output "out", and a single PRelu node. Contents not human-readable.)