Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
CropAndResize op (#3893)
Browse files Browse the repository at this point in the history
* Stub for CropAndResize

* Cut and pasteo

* Need a cast
  • Loading branch information
diyessi authored Nov 18, 2019
1 parent 1ac3e5c commit 90c70dd
Show file tree
Hide file tree
Showing 10 changed files with 391 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/ngraph/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ set (SRC
op/cos.hpp
op/cosh.cpp
op/cosh.hpp
op/crop_and_resize.cpp
op/crop_and_resize.hpp
op/dequantize.cpp
op/dequantize.hpp
op/divide.cpp
Expand Down
1 change: 1 addition & 0 deletions src/ngraph/ngraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/crop_and_resize.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
Expand Down
9 changes: 9 additions & 0 deletions src/ngraph/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,9 @@ namespace ngraph
/// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
Input<Node> input(size_t input_index);

// Simplify migration from 0.25.1
Output<Node> input_value(size_t input_index) const;

/// \return A handle to the `input_index`th input of this node.
/// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
Input<const Node> input(size_t input_index) const;
Expand Down Expand Up @@ -650,6 +653,12 @@ namespace ngraph
return Input<const Node>(this, input_index);
}

// Simplify migration from 0.25.1
inline Output<Node> Node::input_value(size_t input_index) const
{
return input(input_index).get_source_output();
}

inline Output<Node> Node::output(size_t output_index)
{
if (output_index >= m_outputs.size())
Expand Down
198 changes: 198 additions & 0 deletions src/ngraph/op/crop_and_resize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <vector>

#include "ngraph/op/constant.hpp"
#include "ngraph/op/crop_and_resize.hpp"

using namespace std;
using namespace ngraph;

const string op::CropAndResize::type_name{"CropAndResize"};

op::CropAndResize::CropAndResize(const Output<Node>& image,
const Output<Node>& boxes,
const Output<Node>& box_indices,
const Output<Node>& crop_size,
ResizeMethod resize_method,
float extrapolation_value)
: Op({image, boxes, box_indices, crop_size})
, m_resize_method(resize_method)
, m_extrapolation_value(extrapolation_value)
{
constructor_validate_and_infer_types();
}

void op::CropAndResize::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this, get_input_size() == 4);
NODE_VALIDATION_CHECK(
this, m_resize_method != ResizeMethod::unspecified, "Resize method not specified");
auto image = input_value(0);
auto& image_et = image.get_element_type();

// Will override if we can determine the shape
set_output_type(0, image_et, {});

auto image_shape = image.get_partial_shape();
Dimension image_depth;
if (image_shape.is_static())
{
NODE_VALIDATION_CHECK(
this, static_cast<int64_t>(image_shape.rank()) == 4, "Image must be NHWC");
image_depth = image_shape[3];
}

auto boxes = input_value(1);
auto boxes_shape = boxes.get_partial_shape();
if (boxes_shape.is_static())
{
auto boxes_rank = boxes_shape.rank();
NODE_VALIDATION_CHECK(this, static_cast<int64_t>(boxes_rank) == 2, "Boxes must be 2d");
auto boxes_dim1 = boxes_shape[1];
NODE_VALIDATION_CHECK(
this, static_cast<int64_t>(boxes_dim1) == 4, "Second boxes dimension must be 4");
}
NODE_VALIDATION_CHECK(
this, boxes.get_element_type().is_real(), "Boxes must be real values in [0, 1]");

auto box_indices = input_value(2);
auto box_indices_shape = box_indices.get_partial_shape();
Dimension num_boxes;
if (box_indices_shape.is_static())
{
NODE_VALIDATION_CHECK(this,
static_cast<int64_t>(box_indices_shape.rank()) == 1,
"Box indices must have rank 1");
num_boxes = box_indices_shape[0];
}
NODE_VALIDATION_CHECK(
this, box_indices.get_element_type().is_integral(), "Box indices must be integers");

auto crop_size = input_value(3);
auto crop_size_shape = crop_size.get_partial_shape();
auto crop_size_rank = crop_size_shape.rank();
NODE_VALIDATION_CHECK(this,
crop_size_shape.is_static() || crop_size_rank.is_dynamic(),
"Dynamic crop_size not supported");

NODE_VALIDATION_CHECK(
this, static_cast<int64_t>(crop_size_rank) == 1, "crop_size must be a vector");
NODE_VALIDATION_CHECK(this,
static_cast<int64_t>(crop_size_shape[0]) == 2,
"crop_size must be a vector of length 2");
auto& crop_size_et = crop_size.get_element_type();
NODE_VALIDATION_CHECK(this, crop_size_et.is_integral(), "crops_size must be integral");
auto crop_size_node = crop_size.get_node_shared_ptr();
NODE_VALIDATION_CHECK(this, crop_size_node->is_constant(), "crop_size must be a constant");
auto crop_size_const = static_pointer_cast<op::Constant>(crop_size_node);
if (crop_size_et == element::i8)
{
auto v = crop_size_const->get_vector<int8_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::u8)
{
auto v = crop_size_const->get_vector<uint8_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::i16)
{
auto v = crop_size_const->get_vector<int16_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::u16)
{
auto v = crop_size_const->get_vector<uint16_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::i32)
{
auto v = crop_size_const->get_vector<int32_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::u32)
{
auto v = crop_size_const->get_vector<uint32_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::i64)
{
auto v = crop_size_const->get_vector<int64_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::u64)
{
auto v = crop_size_const->get_vector<uint64_t>();
set_output_type(
0,
image_et,
{num_boxes, static_cast<int64_t>(v[0]), static_cast<int64_t>(v[1]), image_depth});
}
else
{
NODE_VALIDATION_CHECK(this, false, "Unknown integral type for crop size");
}
}

shared_ptr<Node> op::CropAndResize::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<CropAndResize>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
m_resize_method,
m_extrapolation_value);
}

static const vector<pair<string, op::CropAndResize::ResizeMethod>>& get_resize_pairs()
{
static vector<pair<string, op::CropAndResize::ResizeMethod>> pairs{
{"unspecified", op::CropAndResize::ResizeMethod::unspecified},
{"bilinear", op::CropAndResize::ResizeMethod::bilinear},
{"nearest", op::CropAndResize::ResizeMethod::nearest}};
return pairs;
}

const string& ngraph::as_string(op::CropAndResize::ResizeMethod resize_method)
{
for (auto& p : get_resize_pairs())
{
if (p.second == resize_method)
{
return p.first;
}
}
throw ngraph_error("Internal error: unhandled resize method");
}

namespace ngraph
{
template <>
op::CropAndResize::ResizeMethod as_type<op::CropAndResize::ResizeMethod>(const std::string& s)
{
for (auto& p : get_resize_pairs())
{
if (p.first == s)
{
return p.second;
}
}
throw ngraph_error("Internal error: unhandled resize method name");
}
}
76 changes: 76 additions & 0 deletions src/ngraph/op/crop_and_resize.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/op/op.hpp"

namespace ngraph
{
namespace op
{
class CropAndResize : public Op
{
public:
enum class ResizeMethod
{
unspecified,
bilinear,
nearest
};

NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a crop and resize operation.
CropAndResize() = default;

/// \param image [N, H, W, C]
/// \param boxes [NUM_BOXES, 4] where boxes[box] is [y1, x1, y2, x2] each in [0, 1]
/// \param box_indices [NUM_BOXES] in [0, N)
/// \param crop_size [crop_height, crop_width]
CropAndResize(const Output<Node>& image,
const Output<Node>& boxes,
const Output<Node>& box_indices,
const Output<Node>& crop_size,
ResizeMethod resize_method,
float extrapolation_value);

void validate_and_infer_types() override;

std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;

ResizeMethod get_resize_method() const { return m_resize_method; }
void set_resize_method(ResizeMethod resize_method) { m_resize_method = resize_method; }
float get_extrapolation_value() const { return m_extrapolation_value; }
void set_extrapolation_value(float extrapolation_value)
{
m_extrapolation_value = extrapolation_value;
}

private:
ResizeMethod m_resize_method{ResizeMethod::unspecified};
float m_extrapolation_value{0};
};
}

const std::string& as_string(op::CropAndResize::ResizeMethod);
template <typename T>
T as_type(const std::string&);

template <>
op::CropAndResize::ResizeMethod as_type<op::CropAndResize::ResizeMethod>(const std::string&);
}
1 change: 1 addition & 0 deletions src/ngraph/op/op_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ NGRAPH_OP(ConvolutionBackpropData, ngraph::op)
NGRAPH_OP(ConvolutionBackpropFilters, ngraph::op)
NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP(Cosh, ngraph::op)
NGRAPH_OP(CropAndResize, ngraph::op)
NGRAPH_OP(Dequantize, ngraph::op)
NGRAPH_OP(Divide, ngraph::op)
NGRAPH_OP(Dot, ngraph::op)
Expand Down
5 changes: 5 additions & 0 deletions src/ngraph/runtime/interpreter/int_executable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,11 @@ class ngraph::runtime::interpreter::INTExecutable : public Executable
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::CropAndResize:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::Dequantize:
{
const op::Dequantize* dequantize = static_cast<const op::Dequantize*>(&node);
Expand Down
17 changes: 17 additions & 0 deletions src/ngraph/serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/crop_and_resize.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
Expand Down Expand Up @@ -1117,6 +1118,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Cosh>(args[0]);
break;
}
case OP_TYPEID::CropAndResize:
{
auto resize_method =
as_type<op::CropAndResize::ResizeMethod>(node_js.at("resize_method").get<string>());
auto extrapolation_value = node_js.at("extrapolation_value").get<float>();
node = make_shared<op::CropAndResize>(
args[0], args[1], args[2], args[3], resize_method, extrapolation_value);
break;
}
case OP_TYPEID::DepthToSpace:
{
auto block_size = node_js.at("block_size").get<size_t>();
Expand Down Expand Up @@ -2363,6 +2373,13 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Cosh: { break;
}
case OP_TYPEID::CropAndResize:
{
auto tmp = static_cast<const op::CropAndResize*>(&n);
node["resize_method"] = as_string(tmp->get_resize_method());
node["extrapolation_value"] = tmp->get_extrapolation_value();
break;
}
case OP_TYPEID::Dequantize:
{
auto tmp = dynamic_cast<const op::Dequantize*>(&n);
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ set(SRC
type_prop/convert.cpp
type_prop/convolution.cpp
type_prop/convolution_bias.cpp
type_prop/crop_and_resize.cpp
type_prop/depth_to_space.cpp
type_prop/dequantize.cpp
type_prop/dot.cpp
Expand Down
Loading

0 comments on commit 90c70dd

Please sign in to comment.