Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
Atan2 op (#3859)
Browse files Browse the repository at this point in the history
* Atan2 op
  • Loading branch information
diyessi authored Nov 7, 2019
1 parent 1d9a495 commit 6e405b8
Show file tree
Hide file tree
Showing 18 changed files with 301 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/ngraph/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ set (SRC
op/asin.hpp
op/atan.cpp
op/atan.hpp
op/atan2.cpp
op/atan2.hpp
op/avg_pool.cpp
op/avg_pool.hpp
op/batch_norm.cpp
Expand Down
1 change: 1 addition & 0 deletions src/ngraph/ngraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
Expand Down
43 changes: 43 additions & 0 deletions src/ngraph/op/atan2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "ngraph/op/atan2.hpp"

using namespace std;
using namespace ngraph;

const string op::Atan2::type_name{"Atan2"};

op::Atan2::Atan2(const Output<Node>& y, const Output<Node>& x, const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic(y, x, autob)
{
constructor_validate_and_infer_types();
}

shared_ptr<Node> op::Atan2::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Atan2>(new_args.at(0), new_args.at(1), this->get_autob());
}

void op::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
throw ngraph_error("Autodiff not supported for Atan2");
}
50 changes: 50 additions & 0 deletions src/ngraph/op/atan2.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <memory>

#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"

namespace ngraph
{
namespace op
{
/// \brief Elementwise full arctan operation
class Atan2 : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Atan2() = default;

/// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed order).
///
/// \param y
/// \param x
Atan2(const Output<Node>& y,
const Output<Node>& x,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;

protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
}
1 change: 1 addition & 0 deletions src/ngraph/op/op_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ NGRAPH_OP(ArgMax, ngraph::op)
NGRAPH_OP(ArgMin, ngraph::op)
NGRAPH_OP(Asin, ngraph::op)
NGRAPH_OP(Atan, ngraph::op)
NGRAPH_OP(Atan2, ngraph::op)
NGRAPH_OP(AvgPool, ngraph::op)
NGRAPH_OP(AvgPoolBackprop, ngraph::op)
NGRAPH_OP(BatchMatMul, ngraph::op)
Expand Down
2 changes: 2 additions & 0 deletions src/ngraph/pass/cse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/constant.hpp"
Expand Down Expand Up @@ -151,6 +152,7 @@ static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node
{TI(op::Acos), cse_unarywise},
{TI(op::Asin), cse_unarywise},
{TI(op::Atan), cse_unarywise},
{TI(op::Atan2), cse_binarywise},
{TI(op::Ceiling), cse_unarywise},
{TI(op::Constant), cse_constant},
{TI(op::Cos), cse_unarywise},
Expand Down
9 changes: 9 additions & 0 deletions src/ngraph/runtime/cpu/cpu_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "ngraph/op/and.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/cos.hpp"
Expand Down Expand Up @@ -75,6 +76,7 @@
#include "ngraph/runtime/cpu/kernel/and.hpp"
#include "ngraph/runtime/cpu/kernel/asin.hpp"
#include "ngraph/runtime/cpu/kernel/atan.hpp"
#include "ngraph/runtime/cpu/kernel/atan2.hpp"
#include "ngraph/runtime/cpu/kernel/broadcast.hpp"
#include "ngraph/runtime/cpu/kernel/ceil.hpp"
#include "ngraph/runtime/cpu/kernel/cos.hpp"
Expand Down Expand Up @@ -311,6 +313,12 @@ namespace ngraph
BUILD_UNARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::atan);
}

template <>
void Builder::BUILDER_DECL(ngraph::op::Atan2)
{
BUILD_BINARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::atan2);
}

template <>
void Builder::BUILDER_DECL(ngraph::op::Ceiling)
{
Expand Down Expand Up @@ -628,6 +636,7 @@ namespace ngraph
REGISTER_OP_BUILDER(Acos);
REGISTER_OP_BUILDER(Asin);
REGISTER_OP_BUILDER(Atan);
REGISTER_OP_BUILDER(Atan2);
REGISTER_OP_BUILDER(Ceiling);
REGISTER_OP_BUILDER(Cos);
REGISTER_OP_BUILDER(Cosh)
Expand Down
16 changes: 16 additions & 0 deletions src/ngraph/runtime/cpu/cpu_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
Expand Down Expand Up @@ -1829,6 +1830,21 @@ namespace ngraph
writer.block_end();
}

template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Atan2)
{
(void)external_function;
(void)node;
writer.block_begin();
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = atan2(" << args[0].get_name() << ", "
<< args[1].get_name() << "[i]);\n";
writer.block_end();
writer.block_end();
}

static void emitArgMinArgMax(const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out,
size_t reduction_axis,
Expand Down
2 changes: 2 additions & 0 deletions src/ngraph/runtime/cpu/cpu_external_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
Expand Down Expand Up @@ -365,6 +366,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::ArgMax), &runtime::cpu::CPU_Emitter::emit<op::ArgMax>},
{TI(ngraph::op::Acos), &runtime::cpu::CPU_Emitter::emit<op::Acos>},
{TI(ngraph::op::Atan), &runtime::cpu::CPU_Emitter::emit<op::Atan>},
{TI(ngraph::op::Atan2), &runtime::cpu::CPU_Emitter::emit<op::Atan2>},
{TI(ngraph::op::ReplaceSlice), &runtime::cpu::CPU_Emitter::emit<op::ReplaceSlice>},
{TI(ngraph::op::UpdateSlice), &runtime::cpu::CPU_Emitter::emit<op::UpdateSlice>},
{TI(ngraph::op::OneHot), &runtime::cpu::CPU_Emitter::emit<op::OneHot>},
Expand Down
56 changes: 56 additions & 0 deletions src/ngraph/runtime/cpu/kernel/atan2.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cmath>

#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

#include "ngraph/runtime/cpu/cpu_executor.hpp"

namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
template <typename ElementType>
void atan2(void* input0, void* input1, void* output, size_t count, int arena)
{
Eigen::array<Eigen::Index, 1> out_dims, in_dims;

out_dims[0] = in_dims[0] = count;

Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> out(
static_cast<ElementType*>(output), out_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in0(
static_cast<ElementType*>(input0), in_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in1(
static_cast<ElementType*>(input1), in_dims);

out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) =
in0.binaryExpr(in1, [](ElementType y, ElementType x) {
return static_cast<ElementType>(std::atan2(y, x));
});
}
}
}
}
}
10 changes: 10 additions & 0 deletions src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
#include "ngraph/runtime/reference/argmin.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/atan2.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_mat_mul.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
Expand Down Expand Up @@ -357,6 +358,15 @@ class ngraph::runtime::gcpu::GCPUExecutable : public Executable
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Atan2:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::atan2<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::AvgPool:
{
const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
Expand Down
1 change: 1 addition & 0 deletions src/ngraph/runtime/intelgpu/intelgpu_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2050,6 +2050,7 @@ shared_ptr<runtime::Executable>
break;
}
case OP_TYPEID::AllReduce:
case OP_TYPEID::Atan2:
case OP_TYPEID::BatchMatMul:
case OP_TYPEID::BroadcastDistributed:
case OP_TYPEID::BroadcastLike:
Expand Down
10 changes: 10 additions & 0 deletions src/ngraph/runtime/interpreter/int_executable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
#include "ngraph/runtime/reference/argmin.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/atan2.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_mat_mul.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
Expand Down Expand Up @@ -367,6 +368,15 @@ class ngraph::runtime::interpreter::INTExecutable : public Executable
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Atan2:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::atan2<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::AvgPool:
{
const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
Expand Down
1 change: 1 addition & 0 deletions src/ngraph/runtime/plaidml/unit_test.manifest
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ scatter_nd_add_2d_to_3d

# To be triaged -- bad kernels, numerical accuracy, edge conditions,
# unimplemented functionality, &c
atan2
cos
erf
sin
Expand Down
38 changes: 38 additions & 0 deletions src/ngraph/runtime/reference/atan2.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cmath>
#include <cstddef>

namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename X, typename Y, typename Z>
void atan2(const X* py, const Y* px, Z* pout, size_t count)
{
for (size_t i = 0; i < count; i++)
{
*pout++ = static_cast<Z>(std::atan2(*py++, *px++));
}
}
}
}
}
Loading

0 comments on commit 6e405b8

Please sign in to comment.