Selu operator (#642)
* code backup

* clang format

* support for selu operator

* clang format

* added an onnx unit test for selu

* clang format

* add more unit tests for the selu operation

Co-authored-by: mvermeulen <[email protected]>
scxiao and mvermeulen authored Sep 25, 2020
1 parent 5494024 commit 48fa934
Showing 6 changed files with 109 additions and 3 deletions.
44 changes: 44 additions & 0 deletions src/onnx/onnx.cpp
@@ -166,6 +166,7 @@ struct onnx_parser
add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("Selu", &onnx_parser::parse_selu);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Split", &onnx_parser::parse_split);
@@ -1531,6 +1532,49 @@ struct onnx_parser

        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }

    instruction_ref
    parse_selu(const std::string&, const node_info& info, std::vector<instruction_ref> args)
    {
        auto type = args[0]->get_shape().type();
        auto lens = args[0]->get_shape().lens();
        // Defaults per the ONNX Selu specification
        float alpha = 1.67326f;
        if(contains(info.attributes, "alpha"))
        {
            alpha = info.attributes.at("alpha").f();
        }

        float gamma = 1.0507f;
        if(contains(info.attributes, "gamma"))
        {
            gamma = info.attributes.at("gamma").f();
        }

        // gamma is folded in as gamma / 2; see the identity checked after this
        // file's diff
        auto l_alpha = prog.add_literal({{type, {1}}, {alpha}});
        auto l_gamma = prog.add_literal({{type, {1}}, {gamma / 2.0f}});
        if(lens != std::vector<std::size_t>{1})
        {
            l_alpha =
                prog.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
            l_gamma =
                prog.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
        }

        auto sign_x = prog.add_instruction(make_op("sign"), args[0]);
        auto exp_x  = prog.add_instruction(make_op("exp"), args[0]);

        // alpha * e^x - alpha, i.e. alpha * (e^x - 1)
        auto alpha_ex  = prog.add_instruction(make_op("mul"), l_alpha, exp_x);
        auto aex_alpha = prog.add_instruction(make_op("sub"), alpha_ex, l_alpha);

        auto ins1 = prog.add_instruction(make_op("add"), aex_alpha, args[0]);
        auto ins2 = prog.add_instruction(make_op("sub"), aex_alpha, args[0]);

        // sign(x) selects between the positive and negative branches
        auto sign2   = prog.add_instruction(make_op("mul"), sign_x, ins2);
        auto ins_sub = prog.add_instruction(make_op("sub"), ins1, sign2);

        return prog.add_instruction(make_op("mul"), ins_sub, l_gamma);
    }

    // Use a literal instruction to replace the shape since the output of the
    // shape operator is a literal in migraphx
    instruction_ref
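A note on the decomposition in parse_selu above: with s = sign(x) and
a = alpha*e^x - alpha, the parser returns gamma/2 * ((a + x) - s*(a - x)).
Checking this against the ONNX definition selu(x) = gamma*x for x > 0 and
gamma*alpha*(e^x - 1) otherwise:

\[
\frac{\gamma}{2}\bigl[(a + x) - s\,(a - x)\bigr] =
\begin{cases}
\frac{\gamma}{2}\,(a + x - a + x) = \gamma x, & x > 0 \;(s = 1)\\
\frac{\gamma}{2}\,(a + 0) = 0, & x = 0 \;(s = 0,\ a = \alpha(e^{0} - 1) = 0)\\
\frac{\gamma}{2}\,(a + x + a - x) = \gamma\,\alpha(e^{x} - 1), & x < 0 \;(s = -1)
\end{cases}
\]

This is why l_gamma holds gamma / 2.0f rather than gamma, and why no explicit
branch or where-style select is needed.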
14 changes: 14 additions & 0 deletions test/onnx/gen_onnx.py
@@ -2390,6 +2390,20 @@ def reshape_non_standard_test():
    return ([trans, res], [x], [y])


@onnx_test
def selu_test():
    x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [2, 3])
    y = helper.make_tensor_value_info('y', TensorProto.DOUBLE, [2, 3])

    node = onnx.helper.make_node('Selu',
                                 inputs=['x'],
                                 outputs=['y'],
                                 alpha=0.3,
                                 gamma=0.5)

    return ([node], [x], [y])


@onnx_test
def shape_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
32 changes: 32 additions & 0 deletions test/onnx/onnx_test.cpp
@@ -1767,6 +1767,38 @@ TEST_CASE(round_test)
    EXPECT(p == prog);
}

TEST_CASE(selu_test)
{
    migraphx::program p;
    std::vector<std::size_t> lens = {2, 3};
    migraphx::shape s{migraphx::shape::double_type, lens};
    auto x = p.add_parameter("x", s);

    migraphx::shape ls{migraphx::shape::double_type, {1}};
    auto la = p.add_literal({ls, {0.3}});
    // 0.25 is gamma / 2, matching the constant the parser folds in
    auto lg   = p.add_literal({ls, {0.25}});
    auto mbla = p.add_instruction(migraphx::op::multibroadcast{lens}, la);
    auto mblg = p.add_instruction(migraphx::op::multibroadcast{lens}, lg);

    auto sign_x = p.add_instruction(migraphx::op::sign{}, x);
    auto exp_x  = p.add_instruction(migraphx::op::exp{}, x);

    auto mlax  = p.add_instruction(migraphx::op::mul{}, mbla, exp_x);
    auto smlax = p.add_instruction(migraphx::op::sub{}, mlax, mbla);

    auto item1 = p.add_instruction(migraphx::op::add{}, smlax, x);
    auto item2 = p.add_instruction(migraphx::op::sub{}, smlax, x);

    auto sitem2 = p.add_instruction(migraphx::op::mul{}, sign_x, item2);
    auto item12 = p.add_instruction(migraphx::op::sub{}, item1, sitem2);
    auto r      = p.add_instruction(migraphx::op::mul{}, item12, mblg);
    p.add_return({r});

    auto prog = migraphx::parse_onnx("selu_test.onnx");

    EXPECT(p == prog);
}

TEST_CASE(shape_test)
{
    migraphx::program p;
Binary file added test/onnx/selu_test.onnx
Binary file not shown.
19 changes: 19 additions & 0 deletions test/onnx/verify_onnx.cpp
@@ -127,6 +127,25 @@ TEST_CASE(gather_elements)
    EXPECT(migraphx::verify_range(result_vector, gold));
}

TEST_CASE(selu_test)
{
    migraphx::program p = migraphx::parse_onnx("selu_test.onnx");
    p.compile(migraphx::cpu::target{});

    migraphx::shape xs{migraphx::shape::double_type, {2, 3}};
    std::vector<double> x_data = {1.1, 2.1, 0.0, -1.3, -5.3, 12.0};

    migraphx::program::parameter_map pp;
    pp["x"] = migraphx::argument(xs, x_data.data());

    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<float> gold = {0.55, 1.05, 0, -0.10912, -0.149251, 6};
    EXPECT(migraphx::verify_range(result_vector, gold));
}
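For reference, the gold values follow directly from the Selu definition with the
alpha=0.3, gamma=0.5 attributes that gen_onnx.py bakes into selu_test.onnx; a
minimal NumPy sketch (not part of this commit) reproduces them:

    import numpy as np

    # Attributes set in gen_onnx.py's selu_test()
    alpha, gamma = 0.3, 0.5
    x = np.array([1.1, 2.1, 0.0, -1.3, -5.3, 12.0])

    # selu(x) = gamma * x for x > 0, gamma * alpha * (e^x - 1) otherwise
    gold = gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1))
    print(gold)  # [0.55, 1.05, 0, -0.10912, -0.14925126, 6]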

TEST_CASE(where_test)
{
    migraphx::program p = migraphx::parse_onnx("where_test.onnx");
3 changes: 0 additions & 3 deletions test/py/onnx_backend_test.py
@@ -249,9 +249,6 @@ def create_backend_test(testname=None, target_device=None):
    backend_test.exclude(r'test_not_3d_cpu')
    backend_test.exclude(r'test_not_4d_cpu')
    backend_test.exclude(r'test_pow_types_*')
-   backend_test.exclude(r'test_selu_cpu')
-   backend_test.exclude(r'test_selu_default_cpu')
-   backend_test.exclude(r'test_selu_example_cpu')
    backend_test.exclude(r'test_size_cpu')
    backend_test.exclude(r'test_size_example_cpu')
    backend_test.exclude(r'test_softmax_cross_entropy_*')
