Skip to content

Commit abf4b77

Browse files
dsqx71 authored and piiswrong committed
[OP] add BilinearSamplingOp and GridGeneratorOp
fix inconsistency
1 parent e086a1f commit abf4b77

File tree

9 files changed

+1305
-0
lines changed

9 files changed

+1305
-0
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,4 @@ List of Contributors
120120
* [Wei Wu](https://github.com/lazyparser)
121121
* [Shishi Duan](https://github.com/burness)
122122
* [Yu Du](https://github.com/Answeror)
123+
* [Xu Dong](https://github.com/dsqx71)

src/operator/bilinear_sampler-inl.h

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
/*!
2+
* Copyright (c) 2017 by Contributors
3+
* \file bilinear_sampler-inl.h
4+
* \brief
5+
* \author Xu Dong
6+
*/
7+
#ifndef MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_
8+
#define MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_
9+
10+
#include <dmlc/logging.h>
11+
#include <dmlc/parameter.h>
12+
#include <mxnet/operator.h>
13+
#include <vector>
14+
#include <map>
15+
#include <string>
16+
#include <utility>
17+
#include "./operator_common.h"
18+
19+
namespace mxnet {
20+
namespace op {
21+
22+
namespace bs {
23+
enum BilinearSamplerOpInputs {kData, kGrid};
24+
enum BilinearSamplerOpOutputs {kOut, kTmp};
25+
}
26+
27+
// Parameter holder for BilinearSampler. The operator currently takes no
// tunable parameters; this struct exists so the op plugs into the standard
// dmlc::Parameter registration machinery (__DICT__/__FIELDS__ etc.).
struct BilinearSamplerParam : public dmlc::Parameter<BilinearSamplerParam> {
  DMLC_DECLARE_PARAMETER(BilinearSamplerParam) {
  }
};
31+
32+
template<typename xpu, typename DType>
33+
class BilinearSamplerOp : public Operator {
34+
public:
35+
explicit BilinearSamplerOp(BilinearSamplerParam p) {
36+
this->param_ = p;
37+
}
38+
39+
virtual void Forward(const OpContext &ctx,
40+
const std::vector<TBlob> &in_data,
41+
const std::vector<OpReqType> &req,
42+
const std::vector<TBlob> &out_data,
43+
const std::vector<TBlob> &aux_args) {
44+
using namespace mshadow;
45+
using namespace mshadow::expr;
46+
CHECK_EQ(in_data.size(), 2);
47+
Stream<xpu> *s = ctx.get_stream<xpu>();
48+
49+
Tensor<xpu, 4, DType> data = in_data[bs::kData].get<xpu, 4, DType>(s);
50+
Tensor<xpu, 4, DType> grid = in_data[bs::kGrid].get<xpu, 4, DType>(s);
51+
Tensor<xpu, 4, DType> out = out_data[bs::kOut].get<xpu, 4, DType>(s);
52+
53+
BilinearSamplerForward(out, data, grid);
54+
}
55+
56+
virtual void Backward(const OpContext &ctx,
57+
const std::vector<TBlob> &out_grad,
58+
const std::vector<TBlob> &in_data,
59+
const std::vector<TBlob> &out_data,
60+
const std::vector<OpReqType> &req,
61+
const std::vector<TBlob> &in_grad,
62+
const std::vector<TBlob> &aux_args) {
63+
using namespace mshadow;
64+
using namespace mshadow::expr;
65+
CHECK_EQ(in_data.size(), 2);
66+
Stream<xpu> *s = ctx.get_stream<xpu>();
67+
68+
Tensor<xpu, 4, DType> data = in_data[bs::kData].get<xpu, 4, DType>(s);
69+
Tensor<xpu, 4, DType> grid = in_data[bs::kGrid].get<xpu, 4, DType>(s);
70+
Tensor<xpu, 4, DType> gdata = in_grad[bs::kData].get<xpu, 4, DType>(s);
71+
Tensor<xpu, 4, DType> ggrid = in_grad[bs::kGrid].get<xpu, 4, DType>(s);
72+
Tensor<xpu, 4, DType> grad = out_grad[bs::kOut].get<xpu, 4, DType>(s);
73+
gdata = 0.0f;
74+
ggrid = 0.0f;
75+
BilinearSamplerBackward(gdata, ggrid, grad, data, grid);
76+
}
77+
78+
private:
79+
BilinearSamplerParam param_;
80+
}; // class BilinearSamplerOp
81+
82+
template<typename xpu>
83+
Operator* CreateOp(BilinearSamplerParam param, int dtype);
84+
85+
#if DMLC_USE_CXX11
86+
class BilinearSamplerProp : public OperatorProperty {
87+
public:
88+
int NumVisibleOutputs() const override {
89+
return 1;
90+
}
91+
92+
int NumOutputs() const override {
93+
return 2;
94+
}
95+
96+
std::vector<std::string> ListArguments() const override {
97+
return {"data", "grid"};
98+
}
99+
100+
std::vector<std::string> ListOutputs() const override {
101+
return {"output", "tmp"};
102+
}
103+
104+
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
105+
param_.Init(kwargs);
106+
}
107+
108+
std::map<std::string, std::string> GetParams() const override {
109+
return param_.__DICT__();
110+
}
111+
112+
bool InferShape(std::vector<TShape> *in_shape,
113+
std::vector<TShape> *out_shape,
114+
std::vector<TShape> *aux_shape) const override {
115+
using namespace mshadow;
116+
CHECK_EQ(in_shape->size(), 2) << "Input:[data, grid]";
117+
const TShape &dshape = (*in_shape)[bs::kData];
118+
const TShape &lshape = (*in_shape)[bs::kGrid];
119+
if (dshape.ndim() == 0) return false;
120+
CHECK_EQ(dshape.ndim(), 4) \
121+
<< "input data should be 4D in batch-num_filter-y-x";
122+
if (lshape.ndim() == 0) return false;
123+
CHECK_EQ(lshape.ndim(), 4) \
124+
<< "Sampler grid should be 4D in batch-2-y-x";
125+
CHECK_EQ(dshape[0], lshape[0]);
126+
CHECK_EQ(lshape[1], 2) << "incorrect grid shape[1], should be 2";
127+
// target height
128+
CHECK_GT(lshape[2], 0) \
129+
<< "incorrect grid_shape: " << lshape[2];
130+
// target width
131+
CHECK_GT(lshape[3], 0) \
132+
<< "incorrect grid_shape: " << lshape[3];
133+
out_shape->clear();
134+
// output_shape : (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3])
135+
out_shape->push_back(dshape);
136+
(*out_shape)[bs::kOut][2] = lshape[2];
137+
(*out_shape)[bs::kOut][3] = lshape[3];
138+
out_shape->push_back(Shape4(lshape[0], lshape[2], lshape[3], 2));
139+
return true;
140+
}
141+
142+
bool InferType(std::vector<int> *in_type,
143+
std::vector<int> *out_type,
144+
std::vector<int> *aux_type) const override {
145+
int dtype = -1;
146+
for (size_t i = 0; i < in_type->size(); ++i) {
147+
if (dtype == -1) {
148+
dtype = in_type->at(i);
149+
} else {
150+
CHECK(in_type->at(i) == dtype ||
151+
in_type->at(i) == -1) <<
152+
"Non-uniform data type in BilinearSampler";
153+
}
154+
}
155+
if (dtype == -1) {
156+
LOG(FATAL) << "Not enough information to infer type in BilinearSampler.";
157+
return false;
158+
}
159+
size_t nin = this->ListArguments().size();
160+
in_type->clear();
161+
for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype);
162+
size_t naux = this->ListAuxiliaryStates().size();
163+
aux_type->clear();
164+
for (size_t i = 0; i < naux; ++i) aux_type->push_back(dtype);
165+
size_t nout = this->ListOutputs().size();
166+
out_type->clear();
167+
for (size_t i = 0; i < nout; ++i) out_type->push_back(dtype);
168+
return true;
169+
}
170+
171+
OperatorProperty* Copy() const override {
172+
auto ptr = new BilinearSamplerProp();
173+
ptr->param_ = param_;
174+
return ptr;
175+
}
176+
177+
std::string TypeString() const override {
178+
return "BilinearSampler";
179+
}
180+
181+
std::vector<int> DeclareBackwardDependency(
182+
const std::vector<int> &out_grad,
183+
const std::vector<int> &in_data,
184+
const std::vector<int> &out_data) const override {
185+
return {out_grad[bs::kOut],
186+
in_data[bs::kData],
187+
out_data[bs::kTmp],
188+
in_data[bs::kGrid]};
189+
}
190+
191+
Operator* CreateOperator(Context ctx) const override {
192+
LOG(FATAL) << "Not Implemented.";
193+
return NULL;
194+
}
195+
196+
Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
197+
std::vector<int> *in_type) const override;
198+
199+
private:
200+
BilinearSamplerParam param_;
201+
}; // class BilinearSamplerProp
202+
#endif // DMLC_USE_CXX11
203+
} // namespace op
204+
} // namespace mxnet
205+
#endif // MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_

src/operator/bilinear_sampler.cc

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*!
2+
* Copyright (c) 2017 by Contributors
3+
* \file bilinear_sampler.cc
4+
* \brief
5+
* \author Xu Dong
6+
*/
7+
8+
#include "./bilinear_sampler-inl.h"
9+
10+
namespace mshadow {
11+
// True iff `value` lies inside the closed interval [lowerBound, upperBound].
template<typename DType>
bool between(DType value, int lowerBound, int upperBound) {
  return !(value < lowerBound || value > upperBound);
}
15+
template<typename DType>
16+
inline void BilinearSamplerForward(const Tensor<cpu, 4, DType> &output,
17+
const Tensor<cpu, 4, DType> &input,
18+
const Tensor<cpu, 4, DType> &grid_src) {
19+
DType *out = output.dptr_;
20+
const DType *data = input.dptr_;
21+
const DType *grid = grid_src.dptr_;
22+
int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3);
23+
int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
24+
for (index_t n = 0; n < o_n; ++n) {
25+
for (index_t c = 0; c < o_c; ++c) {
26+
for (index_t h = 0; h < o_h; ++h) {
27+
for (index_t w = 0; w < o_w; ++w) {
28+
index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
29+
index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
30+
DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
31+
DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
32+
index_t top_left_y = static_cast<int>(floor(y_real));
33+
index_t top_left_x = static_cast<int>(floor(x_real));
34+
DType top_left_y_w = 1.0 - (y_real - top_left_y);
35+
DType top_left_x_w = 1.0 - (x_real - top_left_x);
36+
index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
37+
top_left_y * i_w + top_left_x;
38+
DType top_left_v = 0;
39+
DType top_right_v = 0;
40+
DType bottom_left_v = 0;
41+
DType bottom_right_v = 0;
42+
if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1))
43+
top_left_v = *(data + data_index);
44+
if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, i_h-1))
45+
top_right_v = *(data + data_index + 1);
46+
if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1))
47+
bottom_left_v = *(data + data_index + i_w);
48+
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1))
49+
bottom_right_v = *(data + data_index + i_w + 1);
50+
*(out+out_index) = top_left_v * top_left_y_w * top_left_x_w +
51+
top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
52+
bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w +
53+
bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w);
54+
}
55+
}
56+
}
57+
}
58+
}
59+
60+
template<typename DType>
61+
inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
62+
const Tensor<cpu, 4, DType> &ggrid,
63+
const Tensor<cpu, 4, DType> &output_grad,
64+
const Tensor<cpu, 4, DType> &input_data,
65+
const Tensor<cpu, 4, DType> &grid) {
66+
DType *g_input = gdata.dptr_;
67+
DType *grad_grid = ggrid.dptr_;
68+
const DType *grid_src = grid.dptr_;
69+
const DType *grad = output_grad.dptr_;
70+
const DType *data = input_data.dptr_;
71+
int o_n = output_grad.size(0), o_c = output_grad.size(1),
72+
o_h = output_grad.size(2), o_w = output_grad.size(3);
73+
int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
74+
for (index_t n = 0; n < o_n; ++n) {
75+
for (index_t h = 0; h < o_h; ++h) {
76+
for (index_t w = 0; w < o_w; ++w) {
77+
DType top_left_y_gw = 0.0;
78+
DType top_left_x_gw = 0.0;
79+
index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
80+
DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
81+
DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
82+
index_t top_left_y = static_cast<int>(floor(y_real));
83+
index_t top_left_x = static_cast<int>(floor(x_real));
84+
DType top_left_y_w = 1.0 - (y_real - top_left_y);
85+
DType top_left_x_w = 1.0 - (x_real - top_left_x);
86+
for (index_t c = 0; c < o_c; ++c) {
87+
index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
88+
index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w
89+
+ top_left_x;
90+
// calc 4 vertex value in input data
91+
DType top_left_v = 0;
92+
DType top_right_v = 0;
93+
DType bottom_left_v = 0;
94+
DType bottom_right_v = 0;
95+
// calc input grad
96+
if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
97+
*(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
98+
top_left_v = *(data + data_index);
99+
}
100+
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
101+
*(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
102+
* (1.0 - top_left_x_w);
103+
top_right_v = *(data + data_index + 1);
104+
}
105+
if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
106+
*(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
107+
* top_left_x_w;
108+
bottom_left_v = *(data + data_index + i_w);
109+
}
110+
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
111+
*(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
112+
* (1.0 - top_left_x_w);
113+
bottom_right_v = *(data + data_index + i_w + 1);
114+
}
115+
// calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
116+
top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
117+
(top_left_v - top_right_v - bottom_left_v + bottom_right_v)
118+
* top_left_x_w);
119+
top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
120+
(top_left_v - top_right_v - bottom_left_v + bottom_right_v)
121+
* top_left_y_w);
122+
}
123+
// calc grad of grid
124+
*(grad_grid + grid_src_index + o_h * o_w) = top_left_y_gw * (i_h - 1) / 2;
125+
*(grad_grid + grid_src_index) = top_left_x_gw * (i_w - 1) / 2;
126+
}
127+
}
128+
}
129+
}
130+
} // namespace mshadow
131+
132+
namespace mxnet {
133+
namespace op {
134+
template<>
135+
Operator* CreateOp<cpu>(BilinearSamplerParam param, int dtype) {
136+
Operator *op = NULL;
137+
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
138+
op = new BilinearSamplerOp<cpu, DType>(param);
139+
})
140+
return op;
141+
}
142+
143+
// Validate shapes/types up front, then dispatch to the dtype-specialized
// CreateOp via DO_BIND_DISPATCH (which performs the return).
Operator *BilinearSamplerProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                                                std::vector<int> *in_type) const {
  std::vector<int> out_type, aux_type;
  std::vector<TShape> out_shape, aux_shape;
  CHECK(InferType(in_type, &out_type, &aux_type));
  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
}
151+
152+
DMLC_REGISTER_PARAMETER(BilinearSamplerParam);

// Fixes vs. original description text: balanced the bracket in
// "data[batch, channel, y_src, x_src])", corrected "nomalized" ->
// "normalized", and added the missing space in "...input data, the results".
MXNET_REGISTER_OP_PROPERTY(BilinearSampler, BilinearSamplerProp)
.add_argument("data", "Symbol", "Input data to the BilinearsamplerOp.")
.add_argument("grid", "Symbol", "Input grid to the BilinearsamplerOp."
                                "grid has two channels: x_src, y_src")
.add_arguments(BilinearSamplerParam::__FIELDS__())
.describe("Apply bilinear sampling to input feature map.\n "
          "output[batch, channel, y_dst, x_dst] = G(data[batch, channel, y_src, x_src])\n "
          "x_dst, y_dst enumerate all spatial locations in output\n "
          "x_src = grid[batch, 0, y_dst, x_dst]\n "
          "y_src = grid[batch, 1, y_dst, x_dst]\n "
          "G() denotes the bilinear interpolation kernel\n"
          "If (x_src, y_src) is beyond the boundaries of input data, "
          "the results of forward and backward are zeros.\n"
          "The shape of output will be (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3])\n"
          "The operator assumes that grid has been normalized. "
          "If you want to design a CustomOp to manipulate grid, "
          "please refer to GridGeneratorOp.");
171+
} // namespace op
172+
} // namespace mxnet

0 commit comments

Comments
 (0)