|
| 1 | +/*! |
| 2 | + * Copyright (c) 2017 by Contributors |
| 3 | + * \file bilinear_sampler.cc |
| 4 | + * \brief |
| 5 | + * \author Xu Dong |
| 6 | +*/ |
| 7 | + |
| 8 | +#include "./bilinear_sampler-inl.h" |
| 9 | + |
| 10 | +namespace mshadow { |
| 11 | +template<typename DType> |
| 12 | +bool between(DType value, int lowerBound, int upperBound) { |
| 13 | + return (value >= lowerBound && value <= upperBound); |
| 14 | +} |
| 15 | +template<typename DType> |
| 16 | +inline void BilinearSamplerForward(const Tensor<cpu, 4, DType> &output, |
| 17 | + const Tensor<cpu, 4, DType> &input, |
| 18 | + const Tensor<cpu, 4, DType> &grid_src) { |
| 19 | + DType *out = output.dptr_; |
| 20 | + const DType *data = input.dptr_; |
| 21 | + const DType *grid = grid_src.dptr_; |
| 22 | + int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3); |
| 23 | + int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3); |
| 24 | + for (index_t n = 0; n < o_n; ++n) { |
| 25 | + for (index_t c = 0; c < o_c; ++c) { |
| 26 | + for (index_t h = 0; h < o_h; ++h) { |
| 27 | + for (index_t w = 0; w < o_w; ++w) { |
| 28 | + index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; |
| 29 | + index_t grid_index = n * o_h * o_w * 2 + h * o_w + w; |
| 30 | + DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2; |
| 31 | + DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2; |
| 32 | + index_t top_left_y = static_cast<int>(floor(y_real)); |
| 33 | + index_t top_left_x = static_cast<int>(floor(x_real)); |
| 34 | + DType top_left_y_w = 1.0 - (y_real - top_left_y); |
| 35 | + DType top_left_x_w = 1.0 - (x_real - top_left_x); |
| 36 | + index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + |
| 37 | + top_left_y * i_w + top_left_x; |
| 38 | + DType top_left_v = 0; |
| 39 | + DType top_right_v = 0; |
| 40 | + DType bottom_left_v = 0; |
| 41 | + DType bottom_right_v = 0; |
| 42 | + if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) |
| 43 | + top_left_v = *(data + data_index); |
| 44 | + if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) |
| 45 | + top_right_v = *(data + data_index + 1); |
| 46 | + if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1)) |
| 47 | + bottom_left_v = *(data + data_index + i_w); |
| 48 | + if (between(top_left_x+1, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1)) |
| 49 | + bottom_right_v = *(data + data_index + i_w + 1); |
| 50 | + *(out+out_index) = top_left_v * top_left_y_w * top_left_x_w + |
| 51 | + top_right_v * top_left_y_w * (1.0 - top_left_x_w) + |
| 52 | + bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w + |
| 53 | + bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w); |
| 54 | + } |
| 55 | + } |
| 56 | + } |
| 57 | + } |
| 58 | +} |
| 59 | + |
| 60 | +template<typename DType> |
| 61 | +inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata, |
| 62 | + const Tensor<cpu, 4, DType> &ggrid, |
| 63 | + const Tensor<cpu, 4, DType> &output_grad, |
| 64 | + const Tensor<cpu, 4, DType> &input_data, |
| 65 | + const Tensor<cpu, 4, DType> &grid) { |
| 66 | + DType *g_input = gdata.dptr_; |
| 67 | + DType *grad_grid = ggrid.dptr_; |
| 68 | + const DType *grid_src = grid.dptr_; |
| 69 | + const DType *grad = output_grad.dptr_; |
| 70 | + const DType *data = input_data.dptr_; |
| 71 | + int o_n = output_grad.size(0), o_c = output_grad.size(1), |
| 72 | + o_h = output_grad.size(2), o_w = output_grad.size(3); |
| 73 | + int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3); |
| 74 | + for (index_t n = 0; n < o_n; ++n) { |
| 75 | + for (index_t h = 0; h < o_h; ++h) { |
| 76 | + for (index_t w = 0; w < o_w; ++w) { |
| 77 | + DType top_left_y_gw = 0.0; |
| 78 | + DType top_left_x_gw = 0.0; |
| 79 | + index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w; |
| 80 | + DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2; |
| 81 | + DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2; |
| 82 | + index_t top_left_y = static_cast<int>(floor(y_real)); |
| 83 | + index_t top_left_x = static_cast<int>(floor(x_real)); |
| 84 | + DType top_left_y_w = 1.0 - (y_real - top_left_y); |
| 85 | + DType top_left_x_w = 1.0 - (x_real - top_left_x); |
| 86 | + for (index_t c = 0; c < o_c; ++c) { |
| 87 | + index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; |
| 88 | + index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w |
| 89 | + + top_left_x; |
| 90 | + // calc 4 vertex value in input data |
| 91 | + DType top_left_v = 0; |
| 92 | + DType top_right_v = 0; |
| 93 | + DType bottom_left_v = 0; |
| 94 | + DType bottom_right_v = 0; |
| 95 | + // calc input grad |
| 96 | + if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { |
| 97 | + *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w; |
| 98 | + top_left_v = *(data + data_index); |
| 99 | + } |
| 100 | + if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { |
| 101 | + *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w |
| 102 | + * (1.0 - top_left_x_w); |
| 103 | + top_right_v = *(data + data_index + 1); |
| 104 | + } |
| 105 | + if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { |
| 106 | + *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w) |
| 107 | + * top_left_x_w; |
| 108 | + bottom_left_v = *(data + data_index + i_w); |
| 109 | + } |
| 110 | + if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { |
| 111 | + *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w) |
| 112 | + * (1.0 - top_left_x_w); |
| 113 | + bottom_right_v = *(data + data_index + i_w + 1); |
| 114 | + } |
| 115 | + // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src |
| 116 | + top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v + |
| 117 | + (top_left_v - top_right_v - bottom_left_v + bottom_right_v) |
| 118 | + * top_left_x_w); |
| 119 | + top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v + |
| 120 | + (top_left_v - top_right_v - bottom_left_v + bottom_right_v) |
| 121 | + * top_left_y_w); |
| 122 | + } |
| 123 | + // calc grad of grid |
| 124 | + *(grad_grid + grid_src_index + o_h * o_w) = top_left_y_gw * (i_h - 1) / 2; |
| 125 | + *(grad_grid + grid_src_index) = top_left_x_gw * (i_w - 1) / 2; |
| 126 | + } |
| 127 | + } |
| 128 | + } |
| 129 | + } |
| 130 | +} // namespace mshadow |
| 131 | + |
| 132 | +namespace mxnet { |
| 133 | +namespace op { |
| 134 | +template<> |
| 135 | +Operator* CreateOp<cpu>(BilinearSamplerParam param, int dtype) { |
| 136 | + Operator *op = NULL; |
| 137 | + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { |
| 138 | + op = new BilinearSamplerOp<cpu, DType>(param); |
| 139 | + }) |
| 140 | + return op; |
| 141 | +} |
| 142 | + |
| 143 | +Operator *BilinearSamplerProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape, |
| 144 | + std::vector<int> *in_type) const { |
| 145 | + std::vector<TShape> out_shape, aux_shape; |
| 146 | + std::vector<int> out_type, aux_type; |
| 147 | + CHECK(InferType(in_type, &out_type, &aux_type)); |
| 148 | + CHECK(InferShape(in_shape, &out_shape, &aux_shape)); |
| 149 | + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); |
| 150 | +} |
| 151 | + |
| 152 | +DMLC_REGISTER_PARAMETER(BilinearSamplerParam); |
| 153 | + |
| 154 | +MXNET_REGISTER_OP_PROPERTY(BilinearSampler, BilinearSamplerProp) |
| 155 | +.add_argument("data", "Symbol", "Input data to the BilinearsamplerOp.") |
| 156 | +.add_argument("grid", "Symbol", "Input grid to the BilinearsamplerOp." |
| 157 | + "grid has two channels: x_src, y_src") |
| 158 | +.add_arguments(BilinearSamplerParam::__FIELDS__()) |
| 159 | +.describe("Apply bilinear sampling to input feature map.\n " |
| 160 | +"output[batch, channel, y_dst, x_dst] = G(data[batch, channel, y_src, x_src)\n " |
| 161 | +"x_dst, y_dst enumerate all spatial locations in output\n " |
| 162 | +"x_src = grid[batch, 0, y_dst, x_dst]\n " |
| 163 | +"y_src = grid[batch, 1, y_dst, x_dst]\n " |
| 164 | +"G() denotes the bilinear interpolation kernel\n" |
| 165 | +"If (x_src, y_src) is beyond the boundaries of input data," |
| 166 | +"the results of forward and backward are zeros.\n" |
| 167 | +"The shape of output will be (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3])\n" |
| 168 | +"The operator assumes that grid has been nomalized. " |
| 169 | +"If you want to design a CustomOp to manipulate grid, " |
| 170 | +"please refer to GridGeneratorOp."); |
| 171 | +} // namespace op |
| 172 | +} // namespace mxnet |
0 commit comments