Skip to content

Commit

Permalink
Merge pull request #273 from ROCmSoftwarePlatform/rnn_optimization
Browse files Browse the repository at this point in the history
Rnn optimization
  • Loading branch information
mvermeulen authored Jun 21, 2019
2 parents f93eeca + 67c6e63 commit 15eb198
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 182 deletions.
4 changes: 3 additions & 1 deletion src/include/migraphx/op/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ struct binary : op_name<Derived>
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto s1 = args[0].get_shape();
auto s2 = args[1].get_shape();
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().packed() and input2.get_shape().packed())
if(s1 == s2 and input1.get_shape().packed() and input2.get_shape().packed())
{
std::transform(input1.begin(),
input1.end(),
Expand Down
2 changes: 2 additions & 0 deletions src/include/migraphx/stringutils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ inline std::string transform_string(std::string s, F f)

inline std::string to_upper(std::string s) { return transform_string(std::move(s), ::toupper); }

inline std::string to_lower(std::string s) { return transform_string(std::move(s), ::tolower); }

inline bool starts_with(const std::string& value, const std::string& prefix)
{
if(prefix.size() > value.size())
Expand Down
13 changes: 10 additions & 3 deletions src/onnx/onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ struct onnx_parser

void init_actv_func()
{
// Support name format of all lower case or the first letter capital
map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
Expand Down Expand Up @@ -871,7 +872,9 @@ struct onnx_parser
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}

auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
Expand Down Expand Up @@ -962,7 +965,9 @@ struct onnx_parser
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}

// need 4 activation functions
Expand Down Expand Up @@ -1089,7 +1094,9 @@ struct onnx_parser
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}

// need 6 activation functions for bidirectional directions
Expand Down
8 changes: 7 additions & 1 deletion src/py/migraphx_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>

#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
Expand Down Expand Up @@ -101,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
t = as.type_enum();
n = sizeof(as());
}

});

if(n == 0)
{
MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format);
}

auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
return n > 0 ? i / n : 0;
Expand Down
Loading

0 comments on commit 15eb198

Please sign in to comment.