Skip to content

Commit

Permalink
Apply review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
xuchen-intel committed Dec 20, 2024
1 parent 9bf9826 commit 5690a3e
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ class TRANSFORMATIONS_API FakeConvertDecomposition;

class ov::pass::FakeConvertDecomposition : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("FakeConvertDecomposition", "0");
OPENVINO_MATCHER_PASS_RTTI("FakeConvertDecomposition");
FakeConvertDecomposition();
};
Original file line number Diff line number Diff line change
Expand Up @@ -34,49 +34,37 @@ ov::pass::FakeConvertDecomposition::FakeConvertDecomposition() {
const Output<Node> input_scale{fake_convert_node->input_value(1)};
auto input_type = data.get_element_type();

ov::NodeVector decomp_ops;
ov::pass::NodeRegistry decomp_ops;
if (input_type != input_scale.get_element_type()) {
input_type = input_scale.get_element_type();
data = std::make_shared<ov::op::v0::Convert>(data, input_type);
decomp_ops.push_back(data.get_node_shared_ptr());
data = decomp_ops.add(data.get_node_shared_ptr());
}

std::shared_ptr<Node> result;
const auto scale = std::make_shared<ov::op::v1::Multiply>(data, input_scale);
decomp_ops.push_back(scale);
const auto scale = decomp_ops.make<ov::op::v1::Multiply>(data, input_scale);
if (fake_convert_node->get_input_size() == 2) {
const auto downconvert =
std::make_shared<ov::op::v0::Convert>(scale, fake_convert_node->get_destination_element_type());
decomp_ops.push_back(downconvert);
const auto upconvert = std::make_shared<ov::op::v0::Convert>(downconvert, input_type);
decomp_ops.push_back(upconvert);
const auto downconvert = decomp_ops.make<ov::op::v0::Convert>(scale, fake_convert_node->get_destination_element_type());
const auto upconvert = decomp_ops.make<ov::op::v0::Convert>(downconvert, input_type);

result = std::make_shared<ov::op::v1::Divide>(upconvert, input_scale);
decomp_ops.push_back(result);
result = decomp_ops.make<ov::op::v1::Divide>(upconvert, input_scale);
} else {
const Output<Node> input_shift{fake_convert_node->input_value(2)};
const auto shift = std::make_shared<ov::op::v1::Subtract>(scale, input_shift);
decomp_ops.push_back(shift);
const auto shift = decomp_ops.make<ov::op::v1::Subtract>(scale, input_shift);

const auto downconvert =
std::make_shared<ov::op::v0::Convert>(shift, fake_convert_node->get_destination_element_type());
decomp_ops.push_back(downconvert);
const auto upconvert = std::make_shared<ov::op::v0::Convert>(downconvert, input_type);
decomp_ops.push_back(upconvert);
const auto downconvert = decomp_ops.make<ov::op::v0::Convert>(shift, fake_convert_node->get_destination_element_type());
const auto upconvert = decomp_ops.make<ov::op::v0::Convert>(downconvert, input_type);

const auto deshift = std::make_shared<ov::op::v1::Add>(upconvert, input_shift);
decomp_ops.push_back(deshift);
result = std::make_shared<ov::op::v1::Divide>(deshift, input_scale);
decomp_ops.push_back(result);
const auto deshift = decomp_ops.make<ov::op::v1::Add>(upconvert, input_shift);
result = decomp_ops.make<ov::op::v1::Divide>(deshift, input_scale);
}

if (result->get_output_element_type(0) != fake_convert_node->get_output_element_type(0)) {
result = std::make_shared<ov::op::v0::Convert>(result, fake_convert_node->get_output_element_type(0));
decomp_ops.push_back(result);
result = decomp_ops.make<ov::op::v0::Convert>(result, fake_convert_node->get_output_element_type(0));
}

result->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(fake_convert_node, decomp_ops);
ov::copy_runtime_info(fake_convert_node, decomp_ops.get());
ov::replace_node(m.get_match_root(), result);
return true;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,70 +43,67 @@ class FakeConvertDecompositionTest : public ov::test::TestsCommon,
result << "defaultShift=false";
return result.str();
}
};

protected:
void SetUp() override {
FakeConvertDecompositionParams params = this->GetParam();
TEST_P(FakeConvertDecompositionTest, CompareFunctions) {
FakeConvertDecompositionParams params = this->GetParam();

Shape data_shape, scale_shape, shift_shape;
element::Type_t data_prec, dst_prec;
bool default_shift;
std::tie(data_shape, scale_shape, shift_shape, data_prec, dst_prec, default_shift) = params;
Shape data_shape, scale_shape, shift_shape;
element::Type_t data_prec, dst_prec;
bool default_shift;
std::tie(data_shape, scale_shape, shift_shape, data_prec, dst_prec, default_shift) = params;

std::shared_ptr<ov::Model> f(nullptr);
{
const auto data = std::make_shared<opset1::Parameter>(data_prec, PartialShape(data_shape));
const auto scale = std::make_shared<opset1::Constant>(data_prec, scale_shape);
const auto shift = std::make_shared<opset1::Constant>(data_prec, shift_shape);
std::shared_ptr<ov::Model> model(nullptr);
{
const auto data = std::make_shared<opset1::Parameter>(data_prec, PartialShape(data_shape));
const auto scale = std::make_shared<opset1::Constant>(data_prec, scale_shape);
const auto shift = std::make_shared<opset1::Constant>(data_prec, shift_shape);

const auto fake_convert = default_shift
? std::make_shared<opset13::FakeConvert>(data, scale, dst_prec)
: std::make_shared<opset13::FakeConvert>(data, scale, shift, dst_prec);
f = std::make_shared<ov::Model>(NodeVector{fake_convert}, ParameterVector{data});
const auto fake_convert = default_shift
? std::make_shared<opset13::FakeConvert>(data, scale, dst_prec)
: std::make_shared<opset13::FakeConvert>(data, scale, shift, dst_prec);
model = std::make_shared<ov::Model>(NodeVector{fake_convert}, ParameterVector{data});

pass::Manager manager;
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::pass::FakeConvertDecomposition>();
manager.run_passes(f);
pass::Manager manager;
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::pass::FakeConvertDecomposition>();
manager.run_passes(model);

OV_ASSERT_NO_THROW(check_rt_info(f));
}
OV_ASSERT_NO_THROW(check_rt_info(model));
}

std::shared_ptr<ov::Model> f_ref(nullptr);
{
const auto input_data = std::make_shared<opset1::Parameter>(data_prec, PartialShape(data_shape));
const auto input_scale = std::make_shared<opset1::Constant>(data_prec, scale_shape);
const auto input_shift = std::make_shared<opset1::Constant>(data_prec, shift_shape);
ParameterVector params;
params.push_back(input_data);
std::shared_ptr<Node> data = input_data;

std::shared_ptr<Node> result;
const auto scale = std::make_shared<ov::op::v1::Multiply>(data, input_scale);
if (default_shift) {
const auto downconvert = std::make_shared<ov::op::v0::Convert>(scale, dst_prec);
const auto upconvert = std::make_shared<ov::op::v0::Convert>(downconvert, data_prec);

result = std::make_shared<ov::op::v1::Divide>(upconvert, input_scale);
} else {
const auto shift = std::make_shared<ov::op::v1::Subtract>(scale, input_shift);

const auto downconvert = std::make_shared<ov::op::v0::Convert>(shift, dst_prec);
const auto upconvert = std::make_shared<ov::op::v0::Convert>(downconvert, data_prec);

const auto deshift = std::make_shared<ov::op::v1::Add>(upconvert, input_shift);
result = std::make_shared<ov::op::v1::Divide>(deshift, input_scale);
}

f_ref = std::make_shared<ov::Model>(NodeVector{result}, params);
std::shared_ptr<ov::Model> model_ref(nullptr);
{
const auto input_data = std::make_shared<opset1::Parameter>(data_prec, PartialShape(data_shape));
const auto input_scale = std::make_shared<opset1::Constant>(data_prec, scale_shape);
const auto input_shift = std::make_shared<opset1::Constant>(data_prec, shift_shape);
ParameterVector params;
params.push_back(input_data);
std::shared_ptr<Node> data = input_data;

std::shared_ptr<Node> result;
const auto scale = std::make_shared<ov::op::v1::Multiply>(data, input_scale);
if (default_shift) {
const auto downconvert = std::make_shared<ov::op::v0::Convert>(scale, dst_prec);
const auto upconvert = std::make_shared<ov::op::v0::Convert>(downconvert, data_prec);

result = std::make_shared<ov::op::v1::Divide>(upconvert, input_scale);
} else {
const auto shift = std::make_shared<ov::op::v1::Subtract>(scale, input_shift);

const auto downconvert = std::make_shared<ov::op::v0::Convert>(shift, dst_prec);
const auto upconvert = std::make_shared<ov::op::v0::Convert>(downconvert, data_prec);

const auto deshift = std::make_shared<ov::op::v1::Add>(upconvert, input_shift);
result = std::make_shared<ov::op::v1::Divide>(deshift, input_scale);
}

const auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
model_ref = std::make_shared<ov::Model>(NodeVector{result}, params);
}
};

TEST_P(FakeConvertDecompositionTest, CompareFunctions) {}
const auto res = compare_functions(model, model_ref);
ASSERT_TRUE(res.first) << res.second;
}

const std::vector<element::Type_t> data_precisions = {element::Type_t::f32,
element::Type_t::f16,
Expand Down

0 comments on commit 5690a3e

Please sign in to comment.