Skip to content

Commit

Permalink
[core]Migrate NotEqual operator to new API (#20648)
Browse files Browse the repository at this point in the history
* Migrate NotEqual operator to new API

* Remove `visit_attributes` is same as base

---------

Co-authored-by: Michal Lukaszewski <[email protected]>
  • Loading branch information
praasz and mlukasze authored Oct 31, 2023
1 parent ce8ac6f commit 246410b
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 85 deletions.
5 changes: 1 addition & 4 deletions src/core/include/openvino/op/not_equal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,8 @@ class OPENVINO_API NotEqual : public util::BinaryElementwiseComparison {

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
bool visit_attributes(AttributeVisitor& visitor) override;
};
} // namespace v1
} // namespace op
Expand Down
25 changes: 11 additions & 14 deletions src/core/reference/include/openvino/reference/not_equal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,30 @@

#pragma once

#if defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wfloat-equal"
#endif

#include <cstddef>
#include <functional>

#include "openvino/core/shape.hpp"
#include "openvino/op/util/attr_types.hpp"
#include "openvino/reference/autobroadcast_binop.hpp"

namespace ov {
namespace reference {
// Use custom implementation as function instead std::not_equal_to functor, gives smaller binary size.
// If removed or replace check impact on library binary size.
namespace func {
template <class T>
bool not_equal(const T lhs, const T rhs) {
return lhs != rhs;
}
} // namespace func

template <typename T, typename U>
void not_equal(const T* arg0,
const T* arg1,
U* out,
const Shape& arg0_shape,
const Shape& arg1_shape,
const op::AutoBroadcastSpec& broadcast_spec) {
autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, [](T x, T y) -> U {
return static_cast<U>(x != y);
});
autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, func::not_equal<T>);
}
} // namespace reference
} // namespace ov

#if defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
118 changes: 51 additions & 67 deletions src/core/src/op/not_equal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,95 +2,79 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "ngraph/op/not_equal.hpp"
#include "openvino/op/not_equal.hpp"

#include "element_visitor.hpp"
#include "itt.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/validation_util.hpp"
#include "openvino/reference/not_equal.hpp"
#include "utils.hpp"

using namespace std;
using namespace ngraph;

OPENVINO_SUPPRESS_DEPRECATED_START
namespace not_equalop {
namespace {
template <element::Type_t ET>
bool evaluate(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& out,
const op::AutoBroadcastSpec& broadcast_spec) {
ov::reference::not_equal(arg0->get_data_ptr<ET>(),
arg1->get_data_ptr<ET>(),
out->get_data_ptr<element::Type_t::boolean>(),
arg0->get_shape(),
arg1->get_shape(),
namespace ov {
namespace op {
namespace not_equal {
struct Evaluate : element::NoAction<bool> {
using element::NoAction<bool>::visit;
template <element::Type_t ET, class T = fundamental_type_for<ET>>
static result_type visit(const Tensor& in0,
const Tensor& in1,
Tensor& out,
const Shape& shape0,
const Shape& shape1,
const AutoBroadcastSpec& broadcast_spec) {
reference::not_equal(in0.data<const T>(),
in1.data<const T>(),
out.data<fundamental_type_for<element::boolean>>(),
shape0,
shape1,
broadcast_spec);
return true;
}

bool evaluate_not_equal(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& out,
const op::AutoBroadcastSpec& broadcast_spec) {
bool rc = true;
out->set_broadcast(broadcast_spec, arg0, arg1, element::boolean);
switch (arg0->get_element_type()) {
OPENVINO_TYPE_CASE(evaluate_not_equal, boolean, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_not_equal, i32, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_not_equal, i64, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_not_equal, u32, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_not_equal, u64, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_not_equal, f16, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_not_equal, f32, arg0, arg1, out, broadcast_spec);
default:
rc = false;
break;
return true;
}
return rc;
}
} // namespace
} // namespace not_equalop
};
} // namespace not_equal

// ----------------------------------- v1 --------------------------------------
op::v1::NotEqual::NotEqual(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& auto_broadcast)
namespace v1 {
NotEqual::NotEqual(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast) {
constructor_validate_and_infer_types();
}

shared_ptr<Node> op::v1::NotEqual::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> NotEqual::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_NotEqual_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<op::v1::NotEqual>(new_args.at(0), new_args.at(1), this->get_autob());
return std::make_shared<NotEqual>(new_args.at(0), new_args.at(1), get_autob());
}

bool op::v1::NotEqual::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool NotEqual::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v1_NotEqual_evaluate);
OPENVINO_SUPPRESS_DEPRECATED_START
OPENVINO_ASSERT(validate_host_tensor_vector(outputs, 1) && validate_host_tensor_vector(inputs, 2));
OPENVINO_SUPPRESS_DEPRECATED_END
return not_equalop::evaluate_not_equal(inputs[0], inputs[1], outputs[0], get_autob());
OPENVINO_ASSERT(outputs.size() == 1);

outputs[0].set_shape(infer_broadcast_shape(this, inputs));
using namespace ov::element;
return IfTypeOf<boolean, f16, f32, i32, i64, u32, u64>::apply<not_equal::Evaluate>(inputs[0].get_element_type(),
inputs[0],
inputs[1],
outputs[0],
inputs[0].get_shape(),
inputs[1].get_shape(),
get_autob());
}

bool op::v1::NotEqual::has_evaluate() const {
bool NotEqual::has_evaluate() const {
OV_OP_SCOPE(v1_NotEqual_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::boolean:
case ngraph::element::i32:
case ngraph::element::i64:
case ngraph::element::u32:
case ngraph::element::u64:
case ngraph::element::f16:
case ngraph::element::f32:
case element::boolean:
case element::f16:
case element::f32:
case element::i32:
case element::i64:
case element::u32:
case element::u64:
return true;
default:
break;
return false;
}
return false;
}

bool op::v1::NotEqual::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v1_NotEqual_visit_attributes);
BinaryElementwiseComparison::visit_attributes(visitor);
return true;
}
} // namespace v1
} // namespace op
} // namespace ov

0 comments on commit 246410b

Please sign in to comment.