diff --git a/docs/dev/onnx_operators.rst b/docs/dev/onnx_operators.rst index fc621b4f894..7d58431ee39 100644 --- a/docs/dev/onnx_operators.rst +++ b/docs/dev/onnx_operators.rst @@ -697,8 +697,11 @@ Operator Support Matrix | | | | functions are | | | | | not enabled | +--------------------------+-----------+-----------------+------------------------------+ -| RoiAlign | ✅ | FP8, FP16, | | -| | | FP32, FP64 | | +| RoiAlign | ✅ | FP8, FP16, | ``X``, | +| | | FP32, FP64, | ``ROI`` take any floating- | +| | | UINT8, UINT16, | point type; | +| | | UINT32, UINT64, | ``batch_indices`` | +| | | | takes any integral type | +--------------------------+-----------+-----------------+------------------------------+ | Round | ✅ | FP8, FP16, | | | | | FP32, FP64 | | diff --git a/src/include/migraphx/op/roialign.hpp b/src/include/migraphx/op/roialign.hpp index d66e8f0feeb..76b7c8b967e 100644 --- a/src/include/migraphx/op/roialign.hpp +++ b/src/include/migraphx/op/roialign.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -74,6 +74,15 @@ struct roialign auto type = inputs.at(0).type(); // check input correct + if(shape::is_integral(type)) + MIGRAPHX_THROW("ROIALIGN: incorrect type for input data! (should be non-integer)"); + if(shape::is_integral(inputs.at(1).type())) + MIGRAPHX_THROW("ROIALIGN: incorrect data type for rois! (should be non-integer)"); + if(!shape::is_integral(inputs.at(2).type())) + MIGRAPHX_THROW( + "ROIALIGN: incorrect datatype for roi indices! (should be an integral type)"); + if(x_lens.size() != 4) + MIGRAPHX_THROW("ROIALIGN: data input must have 4 dimensions n, c, h, w"); if(bi_lens.size() != 1) { MIGRAPHX_THROW("ROIALIGN: batch indices should be 1 dimension!"); @@ -92,8 +101,8 @@ struct roialign std::vector out_lens = x_lens; out_lens[0] = roi_lens[0]; - out_lens[2] = output_height; - out_lens[3] = output_width; + out_lens[2] = output_width; + out_lens[3] = output_height; return {type, out_lens}; } @@ -115,17 +124,22 @@ struct roialign std::vector results(bin_grid_size[0] * bin_grid_size[1] * output_height * output_width); shape_for_each(comp_s, [&](const auto& idx_v, size_t index) { - std::array p = {idx_v[0], idx_v[1]}; - std::array i = {idx_v[2], idx_v[3]}; + // The p and i indexes correspond to nested looping parameters in ORT that go in y, x + // order. The i[x] value is least significant and iterates the fastest. + std::array p = {idx_v[1], idx_v[0]}; + std::array i = {idx_v[3], idx_v[2]}; // these are always equal + // xy is scaled coordinates of start point of ROI std::array xy{}; + // low, high are floor and ceiling of the xy value (i.e. the bounds of the pixel it lies + // inside) from which we will interpolate. std::array low{}; std::array high{}; for(auto ii : range(p.size())) { xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] + (i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii]; - xy[ii] = (coord_trans_mode == "half_pixel") ? (xy[ii] - 0.5f) : xy[ii]; + if(xy[ii] < -1.0 or xy[ii] > dims[ii]) { results[index] = pos_weight{}; @@ -140,21 +154,18 @@ struct roialign xy[ii] = high[ii] = low[ii] = dims[ii] - 1; } } + results[index].pos = {low[1] * dims[0] + low[0], + low[1] * dims[0] + high[0], + high[1] * dims[0] + low[0], + high[1] * dims[0] + high[0]}; - results[index].pos = {low[0] * dims[1] + low[1], - low[0] * dims[1] + high[1], - high[0] * dims[1] + low[1], - high[0] * dims[1] + high[1]}; - - float ly = xy[0] - low[0]; - float lx = xy[1] - low[1]; + float lx = xy[0] - low[0]; + float ly = xy[1] - low[1]; float hy = 1.0f - ly; float hx = 1.0f - lx; - - // save weights and indeces + // save weights and indices results[index].w = {hy * hx, hy * lx, ly * hx, ly * lx}; }); - return results; } @@ -176,11 +187,12 @@ struct roialign double final(double x, std::size_t y) { return (y == 0) ? 0.0 : (x / y); } }; + // Calculate a pooling value for 1 block of bin_grid_size*bin_grid_size weights template std::tuple calc_pooling(const T& data, const std::array& bin_grid_size, const std::vector& pos_weights, - int64_t index, + int64_t index, // index to c Op op) const { double output_val = op.init(); @@ -208,11 +220,11 @@ struct roialign int64_t n_rois = out_lens[0]; std::size_t channels = out_lens[1]; // output dims of height and width, in all 2-dim arrays, the first dim - // is for height and second dim is for width + // is for height and second dim is for width i.e. (y, x) order std::array out_dims = {out_lens[2], out_lens[3]}; const auto& x_lens = args.at(0).get_shape().lens(); // input dims of height and width - std::array in_dims = {x_lens[2], x_lens[3]}; + std::array in_dims = {x_lens[3], x_lens[2]}; auto roi_s = args.at(1).get_shape(); visit_all(result, args.at(0), args.at(1))([&](auto output, auto x, auto roi) { @@ -220,15 +232,17 @@ struct roialign par_for(n_rois, [&](auto n) { const auto bottom_data = x.begin(); const auto roi_batch_ind = batch_indices[n]; - // Do not using rounding; this implementation detail is critical + // Do not use rounding here even if data is a quantized type; this + // implementation detail is critical + const float offset = (coord_trans_mode == "half_pixel") ? 0.5 : 0.0; std::array roi_starts = { - static_cast(roi[roi_s.index({n, 1})] * spatial_scale), - static_cast(roi[roi_s.index({n, 0})] * spatial_scale)}; + static_cast(roi[roi_s.index({n, 0})] * spatial_scale - offset), + static_cast(roi[roi_s.index({n, 1})] * spatial_scale - offset)}; std::array roi_ends = { - static_cast(roi[roi_s.index({n, 3})] * spatial_scale), - static_cast(roi[roi_s.index({n, 2})] * spatial_scale)}; + static_cast(roi[roi_s.index({n, 2})] * spatial_scale - offset), + static_cast(roi[roi_s.index({n, 3})] * spatial_scale - offset)}; - // Force malformed ROIs to be 1x1 + // Force malformed ROIs to be 1x1, if in output_half_pixel transform mode std::array roi_size{}; std::array bin_size{}; std::array bin_grid_size{}; @@ -236,8 +250,8 @@ struct roialign for(auto ii : range(roi_size.size())) { roi_size[ii] = roi_ends[ii] - roi_starts[ii]; - roi_size[ii] = std::max(roi_size[ii], 1.0f); - + if(coord_trans_mode != "half_pixel") + roi_size[ii] = std::max(roi_size[ii], 1.0f); bin_size[ii] = roi_size[ii] / out_dims[ii]; bin_grid_size[ii] = (sampling_ratio > 0) ? sampling_ratio @@ -247,7 +261,7 @@ struct roialign // we want to precalculate indices and weights shared by all channels, // this is the key point of optimization std::vector comp_lens = { - out_dims[0], out_dims[1], bin_grid_size[0], bin_grid_size[1]}; + out_dims[1], out_dims[0], bin_grid_size[1], bin_grid_size[0]}; shape comp_s{shape::float_type, comp_lens}; auto pre_calc = this->calc_pos_weight(in_dims, comp_s, roi_starts, bin_size, bin_grid_size); @@ -255,14 +269,16 @@ struct roialign std::vector comp_lens1 = {channels, out_dims[0], out_dims[1]}; shape comp_s1{migraphx::shape::float_type, comp_lens1}; std::vector vec_index(channels, 0); + shape_for_each(comp_s1, [&](const auto& idx) { - auto c = idx[0]; + auto c = idx[0]; // channel count auto ph = idx[1]; auto pw = idx[2]; const auto offset_bottom_data = bottom_data + static_cast((roi_batch_ind * channels + c) * in_dims[0] * in_dims[1]); + double output_val; std::tie(output_val, vec_index[c]) = (mode == migraphx::op::pooling_mode::average) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp index b7d7216c690..769c7c978bf 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp @@ -24,6 +24,8 @@ #ifndef MIGRAPHX_GUARD_KERNELS_ROIALIGN_HPP #define MIGRAPHX_GUARD_KERNELS_ROIALIGN_HPP +// #include +// #include #include #include #include @@ -87,13 +89,14 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( xy[ii] = high[ii] = low[ii] = dims[ii] - 1; } } - array locs = {low[0] * dims[1] + low[1], - low[0] * dims[1] + high[1], - high[0] * dims[1] + low[1], - high[0] * dims[1] + high[1]}; + array locs = {low[1] * dims[0] + low[0], + low[1] * dims[0] + high[0], + high[1] * dims[0] + low[0], + high[1] * dims[0] + high[0]}; + + float lx = xy[0] - low[0]; + float ly = xy[1] - low[1]; - float ly = xy[0] - low[0]; - float lx = xy[1] - low[1]; float hy = 1.0f - ly; float hx = 1.0f - lx; // do calculations in floating point and convert final result to required type @@ -104,6 +107,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( return implicit_conversion(pooling(v01, v23)); } +// Calculate a single pooled output value template MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data, const array& roi_starts, @@ -111,17 +115,15 @@ MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data, const array& idx, const array& bin_grid_size, const array& dims, - float roi_offset, Op op) { + // for one idx (output height and width coordinates) we iterate through all bin_grid values using in_dtype = typename Iterator::value_type; in_dtype output_val = in_dtype{op.init()}; const int64_t count = bin_grid_size[0] * bin_grid_size[1]; dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) { array id = {iy, ix}; - array locs = - roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset; - + array locs = roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size; auto val = bilinear_interpolate(data, dims, locs, op); output_val = op(output_val, val); }); @@ -155,7 +157,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, auto channel_num = x_lens[1]; // input dims of height and width, in all 2-dim arrays, the first dim // is for height and second dim is for width - array in_dims = {x_lens[2], x_lens[3]}; + array in_dims = {x_lens[3], x_lens[2]}; const auto stride = index.nglobal(); auto out_s = y_t.get_shape(); @@ -166,6 +168,17 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, const auto& out_lens = out_s.lens; array out_dims = {out_lens[2], out_lens[3]}; + // Compute lens and strides vectors for use in reindexing output. + // Todo: look for a less indirect way to reconcile the ordering of iteration + // between this op. and the reference. + array m_lens{out_lens[0], out_lens[1], out_lens[3], out_lens[2]}; + array m_strides; + m_strides[3] = 1; + for(int k = 2; k >= 0; k--) + { + m_strides[k] = m_strides[k + 1] * m_lens[k + 1]; + } + for(index_int i = index.global; i < out_s.elements(); i += stride) { auto idx = out_s.multi(i); @@ -177,12 +190,17 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, const auto offset_rois = rois + (n * roi_column_num); const int batch_ind = ind[n]; + // Note that roi_offset in src/targets/gpu/jit/roialign.cpp uses a negative value, so we add + // rather than subtract it here array roi_starts = { - static_cast(offset_rois[1]) * static_cast(s.spatial_scale), - static_cast(offset_rois[0]) * static_cast(s.spatial_scale)}; + static_cast(offset_rois[0]) * static_cast(s.spatial_scale) + s.roi_offset, + static_cast(offset_rois[1]) * static_cast(s.spatial_scale) + + s.roi_offset}; + array roi_ends = { - static_cast(offset_rois[3]) * static_cast(s.spatial_scale), - static_cast(offset_rois[2]) * static_cast(s.spatial_scale)}; + static_cast(offset_rois[2]) * static_cast(s.spatial_scale) + s.roi_offset, + static_cast(offset_rois[3]) * static_cast(s.spatial_scale) + + s.roi_offset}; array roi_size{}; array bin_size{}; @@ -191,36 +209,37 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, for(index_int ii = 0; ii < roi_size.size(); ++ii) { roi_size[ii] = roi_ends[ii] - roi_starts[ii]; - roi_size[ii] = migraphx::max(roi_size[ii], 1.0f); + if(s.roi_offset == 0.f) + roi_size[ii] = migraphx::max(roi_size[ii], 1.0f); bin_size[ii] = roi_size[ii] / out_dims[ii]; bin_grid_size[ii] = (s.sampling_ratio > 0) ? s.sampling_ratio : migraphx::ceil(roi_size[ii] / out_dims[ii]); } - const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); + + // + // Reindexing. Calculations to this point did not iterate in the same order as + // in the reference op; we now calculate the output index corresponding to i + // + size_t pp = i; + size_t jj = (pp / m_strides[0]) * m_strides[0]; + pp = pp % m_strides[0]; + jj += (pp / m_strides[1]) * m_strides[1]; + pp %= m_strides[1]; + pp = pp / m_lens[2] + (pp % m_lens[2]) * m_strides[2]; + jj += pp; + if constexpr(s.is_avg_pooling) { - y_t[i] = calc_pooling(offset_x, - roi_starts, - bin_size, - {ph, pw}, - bin_grid_size, - in_dims, - s.roi_offset, - avg_pool{}); + y_t[jj] = calc_pooling( + offset_x, roi_starts, bin_size, {ph, pw}, bin_grid_size, in_dims, avg_pool{}); } else { - y_t[i] = calc_pooling(offset_x, - roi_starts, - bin_size, - {ph, pw}, - bin_grid_size, - in_dims, - s.roi_offset, - max_pool{}); + y_t[jj] = calc_pooling( + offset_x, roi_starts, bin_size, {ph, pw}, bin_grid_size, in_dims, max_pool{}); } } } diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index e945e369ed8..b8e1945e3ea 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -31,10 +31,13 @@ from onnx.numpy_helper import from_array -def onnx_test(external_data=False): +def onnx_test(external_data=False, opset_imports=None): def create_onnx_test(op_test): def run_test(): op_info = op_test() + opset_id = [helper.make_operatorsetid('', opset_imports) + ] if opset_imports is not None else None + if len(op_info) > 3: graph_def = helper.make_graph(op_info[0], op_test.__name__, @@ -45,7 +48,8 @@ def run_test(): graph_def = helper.make_graph(op_info[0], op_test.__name__, op_info[1], op_info[2]) model_def = helper.make_model(graph_def, - producer_name=op_test.__name__) + producer_name=op_test.__name__, + opset_imports=opset_id) onnx.save_model(model_def, '{}.onnx'.format(op_test.__name__), save_as_external_data=external_data, @@ -10587,8 +10591,11 @@ def rnn_r_3arg_layout_test(): return ([node], [seq, w, r], [hs, output]) -@onnx_test() +@onnx_test(opset_imports=16) def roialign_default_test(): + # The op. ROIAlign had an attribute coordinate_transformation_mode added + # as of Onnx opset 16; we make opset-specific test models which give + # different default values. x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 4, 7, 8]) roi = helper.make_tensor_value_info('rois', TensorProto.FLOAT, [8, 4]) bi = helper.make_tensor_value_info('batch_ind', TensorProto.INT64, [8]) @@ -10601,27 +10608,85 @@ def roialign_default_test(): return ([node], [x, roi, bi], [y]) -@onnx_test() -def roialign_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 5, 4, 7]) +@onnx_test(opset_imports=12) +def roialign_default_test_12(): + # Same model as in roialign_default_test() but with an older opset specified + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 4, 7, 8]) roi = helper.make_tensor_value_info('rois', TensorProto.FLOAT, [8, 4]) bi = helper.make_tensor_value_info('batch_ind', TensorProto.INT64, [8]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [8, 4, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [8, 4, 1, 1]) + + node = onnx.helper.make_node('RoiAlign', + inputs=['x', 'rois', 'batch_ind'], + outputs=['y']) + + return ([node], [x, roi, bi], [y]) + +@onnx_test() +def roialign_test(): + # Roialign with output_half_pixel mode is backward-compatible. + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 5, 4, 7]) + roi = helper.make_tensor_value_info('rois', TensorProto.FLOAT, [2, 4]) + bi = helper.make_tensor_value_info('batch_ind', TensorProto.INT64, [2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 4, 5, 5]) node = onnx.helper.make_node( 'RoiAlign', inputs=['x', 'rois', 'batch_ind'], outputs=['y'], spatial_scale=2.0, output_height=5, - output_width=5, + output_width=3, sampling_ratio=3, + # todo: max test mode="avg", coordinate_transformation_mode="output_half_pixel") return ([node], [x, roi, bi], [y]) +@onnx_test() +def roialign_half_pixel_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 4, 3]) + roi = helper.make_tensor_value_info('rois', TensorProto.FLOAT, [2, 4]) + bi = helper.make_tensor_value_info('batch_ind', TensorProto.INT64, [2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 4, 3]) + + # half_pixel is the newer mode for ROIAlign + node = onnx.helper.make_node('RoiAlign', + inputs=['x', 'rois', 'batch_ind'], + outputs=['y'], + spatial_scale=2.0, + output_height=2, + output_width=3, + sampling_ratio=2, + mode="avg", + coordinate_transformation_mode="half_pixel") + + return ([node], [x, roi, bi], [y]) + + +@onnx_test() +def roialign_half_pixel_max_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 4, 3]) + roi = helper.make_tensor_value_info('rois', TensorProto.FLOAT, [2, 4]) + bi = helper.make_tensor_value_info('batch_ind', TensorProto.INT64, [2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 4, 3]) + + # half_pixel is the newer mode for ROIAlign + node = onnx.helper.make_node('RoiAlign', + inputs=['x', 'rois', 'batch_ind'], + outputs=['y'], + spatial_scale=2.0, + output_height=2, + output_width=3, + sampling_ratio=2, + mode="max", + coordinate_transformation_mode="half_pixel") + + return ([node], [x, roi, bi], [y]) + + @onnx_test() def round_half_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [4, 4]) diff --git a/test/onnx/parse/roialign_default_test.cpp b/test/onnx/parse/roialign_default_test.cpp index b4869740a57..9a14778c12c 100644 --- a/test/onnx/parse/roialign_default_test.cpp +++ b/test/onnx/parse/roialign_default_test.cpp @@ -35,17 +35,40 @@ TEST_CASE(roialign_default_test) auto x = mm->add_parameter("x", sx); auto rois = mm->add_parameter("rois", srois); auto bi = mm->add_parameter("batch_ind", sbi); - - // Due to the onnx model using opset 12, the coordinate_transformation_mode should be set to - // output_half_pixel + // Depending on whether the model was built for Onnx opset 16 or earlier, the default + // coordinate_transformation_mode will be different. These model files had explicit opset given + // when they were created. auto r = mm->add_instruction( - migraphx::make_op("roialign", {{"coordinate_transformation_mode", "output_half_pixel"}}), + migraphx::make_op("roialign", {{"coordinate_transformation_mode", "half_pixel"}}), x, rois, bi); mm->add_return({r}); - auto prog = read_onnx("roialign_default_test.onnx"); - EXPECT(p == prog); } + + +TEST_CASE(roialign_default_12_test) +{ + // opset 12 version + migraphx::shape sx{migraphx::shape::float_type, {10, 4, 7, 8}}; + migraphx::shape srois{migraphx::shape::float_type, {8, 4}}; + migraphx::shape sbi{migraphx::shape::int64_type, {8}}; + + // Opset 12 program + migraphx::program p_12; + auto* mm_12 = p_12.get_main_module(); + auto x_12 = mm_12->add_parameter("x", sx); + auto rois_12 = mm_12->add_parameter("rois", srois); + auto bi_12 = mm_12->add_parameter("batch_ind", sbi); + + auto r_12 = mm_12->add_instruction( + migraphx::make_op("roialign", {{"coordinate_transformation_mode", "output_half_pixel"}}), + x_12, + rois_12, + bi_12); + mm_12->add_return({r_12}); + auto prog_12 = read_onnx("roialign_default_test_12.onnx"); + EXPECT(p_12 == prog_12); +} diff --git a/test/onnx/parse/roialign_test.cpp b/test/onnx/parse/roialign_test.cpp index 05f27b6473c..346346727b2 100644 --- a/test/onnx/parse/roialign_test.cpp +++ b/test/onnx/parse/roialign_test.cpp @@ -27,8 +27,8 @@ TEST_CASE(roialign_test) { migraphx::shape sx{migraphx::shape::float_type, {10, 5, 4, 7}}; - migraphx::shape srois{migraphx::shape::float_type, {8, 4}}; - migraphx::shape sbi{migraphx::shape::int64_type, {8}}; + migraphx::shape srois{migraphx::shape::float_type, {2, 4}}; + migraphx::shape sbi{migraphx::shape::int64_type, {2}}; migraphx::program p; auto* mm = p.get_main_module(); @@ -41,7 +41,7 @@ TEST_CASE(roialign_test) {{"coordinate_transformation_mode", "output_half_pixel"}, {"spatial_scale", 2.0f}, {"output_height", 5}, - {"output_width", 5}, + {"output_width", 3}, {"sampling_ratio", 3}}), x, rois, diff --git a/test/onnx/roialign_default_test.onnx b/test/onnx/roialign_default_test.onnx index 4421e17be60..cc47b78b9df 100644 Binary files a/test/onnx/roialign_default_test.onnx and b/test/onnx/roialign_default_test.onnx differ diff --git a/test/onnx/roialign_default_test_12.onnx b/test/onnx/roialign_default_test_12.onnx new file mode 100644 index 00000000000..1747f61ee12 Binary files /dev/null and b/test/onnx/roialign_default_test_12.onnx differ diff --git a/test/onnx/roialign_half_pixel_max_test.onnx b/test/onnx/roialign_half_pixel_max_test.onnx new file mode 100644 index 00000000000..fc192568d6f Binary files /dev/null and b/test/onnx/roialign_half_pixel_max_test.onnx differ diff --git a/test/onnx/roialign_half_pixel_test.onnx b/test/onnx/roialign_half_pixel_test.onnx new file mode 100644 index 00000000000..4b4ff5dcb2f Binary files /dev/null and b/test/onnx/roialign_half_pixel_test.onnx differ diff --git a/test/onnx/roialign_test.onnx b/test/onnx/roialign_test.onnx index f39485530c4..eb6703f49d5 100644 Binary files a/test/onnx/roialign_test.onnx and b/test/onnx/roialign_test.onnx differ diff --git a/test/onnx/verify/roialign_half_pixel_verify_test.cpp b/test/onnx/verify/roialign_half_pixel_verify_test.cpp new file mode 100644 index 00000000000..2653471af1b --- /dev/null +++ b/test/onnx/verify/roialign_half_pixel_verify_test.cpp @@ -0,0 +1,99 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +// The half_pixel mode for the ROIAlign op +TEST_CASE(roialign_half_pixel_verify_test) +{ + migraphx::program p = read_onnx("roialign_half_pixel_test.onnx"); + p.compile(migraphx::make_target("ref")); + migraphx::shape s{migraphx::shape::float_type, {2, 2, 4, 3}}; + std::vector data(2 * 2 * 4 * 3); + std::iota(data.begin(), data.end(), 0.f); + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + pp["y"] = migraphx::argument(s, data.data()); + + migraphx::shape srois{migraphx::shape::float_type, {2, 4}}; + std::vector rois_data = {1.1, 0.73, 1.7, 1.13, 1.1, 0.73, 2.6, 1.13}; + migraphx::shape sbi{migraphx::shape::int64_type, {2}}; // batch_index + std::vector bi_data = {0, 1}; + + pp["rois"] = migraphx::argument(srois, rois_data.data()); + pp["batch_ind"] = migraphx::argument(sbi, bi_data.data()); + pp["y"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + // Gold values were generated with onnxruntime + std::vector gold = {5.38, 5.4799995, 5.4799995, 6.58, 6.68, 6.68, + 17.38, 17.48, 17.48, 18.58, 18.68, 18.68, + 29.454998, 14.74, 0., 30.654999, 15.34, 0., + 41.455, 20.74, 0., 42.655003, 21.34, 0.}; + + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +// The half_pixel mode for the ROIAlign op, max pooling +TEST_CASE(roialign_half_pixel_max_verify_test) +{ + migraphx::program p = read_onnx("roialign_half_pixel_max_test.onnx"); + p.compile(migraphx::make_target("ref")); + migraphx::shape s{migraphx::shape::float_type, {2, 2, 4, 3}}; + std::vector data(2 * 2 * 4 * 3); + std::iota(data.begin(), data.end(), 0.f); + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + pp["y"] = migraphx::argument(s, data.data()); + + migraphx::shape srois{migraphx::shape::float_type, {2, 4}}; + std::vector rois_data = {1.1, 0.73, 1.7, 1.13, 1.1, 0.73, 2.6, 1.13}; + migraphx::shape sbi{migraphx::shape::int64_type, {2}}; // batch_index + std::vector bi_data = {0, 1}; + + pp["rois"] = migraphx::argument(srois, rois_data.data()); + pp["batch_ind"] = migraphx::argument(sbi, bi_data.data()); + pp["y"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + // Gold values were generated with onnxruntime + std::vector gold = { 4.7 , 4.7 , 4.7 ,5.2799997, 5.2799997, 5.2799997, + + 15.979999 , 15.979999 , 15.979999 , 13.199999 , 13.199999 , 13.199999 , + + + 27.477499 , 27.477499 , 0. ,19.440002 , 19.440002 , 0. , + + 38.8475 , 38.8475 , 0. , 26.730003 , 26.730003 , 0. }; + + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/roialign_verify_test.cpp b/test/onnx/verify/roialign_verify_test.cpp new file mode 100644 index 00000000000..ea9d84e7e8a --- /dev/null +++ b/test/onnx/verify/roialign_verify_test.cpp @@ -0,0 +1,87 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(roialign_verify_test) +{ + migraphx::program p = read_onnx("roialign_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s{migraphx::shape::float_type, {10, 5, 4, 7}}; + std::vector data(10 * 5 * 4 * 7); + std::iota(data.begin(), data.end(), 0); + + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + pp["y"] = migraphx::argument(s, data.data()); + + migraphx::shape srois{migraphx::shape::float_type, {2, 4}}; + std::vector rois_data = {0.1, 0.15, 0.6, 0.35, 2.1, 1.73, 3.8, 2.13}; + migraphx::shape sbi{migraphx::shape::int64_type, {2}}; + std::vector bi_data = {1, 0}; + + pp["rois"] = migraphx::argument(srois, rois_data.data()); + pp["batch_ind"] = migraphx::argument(sbi, bi_data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + // gold values were generated with onnxruntime + std::vector gold = { + 143.16667, 143.49998, 143.83333, 144.56667, 144.9, 145.23334, 145.96667, 146.3, + 146.63333, 147.36667, 147.70001, 148.03334, 148.76666, 149.09999, 149.43333, + + 171.16667, 171.5, 171.83333, 172.56667, 172.90001, 173.23334, 173.96667, 174.3, + 174.63333, 175.36667, 175.70001, 176.03333, 176.76666, 177.09999, 177.43335, + + 199.16667, 199.5, 199.83333, 200.56667, 200.90001, 201.23334, 201.96666, 202.3, + 202.63333, 203.36665, 203.70001, 204.03333, 204.76668, 205.09999, 205.43333, + + 227.16667, 227.5, 227.83333, 228.56668, 228.90001, 229.23332, 229.96669, 230.29999, + 230.63333, 231.36664, 231.70001, 232.03334, 232.76668, 233.09999, 233.43332, + + 255.16667, 255.5, 255.83333, 256.56668, 256.90002, 257.2333, 257.96667, 258.3, + 258.63333, 259.36664, 259.69998, 260.03333, 260.7667, 261.09998, 261.43338, + + 25.766665, 26.807405, 9., 25.766665, 26.807405, 9., 17.177776, 17.871605, + 6., 0., 0., 0., 0., 0., 0., + + 53.766666, 54.807407, 18.333334, 53.766666, 54.807407, 18.333334, 35.844444, 36.538273, + 12.222222, 0., 0., 0., 0., 0., 0., + + 81.76667, 82.8074, 27.666666, 81.76667, 82.8074, 27.666666, 54.51111, 55.204937, + 18.444445, 0., 0., 0., 0., 0., 0., + + 109.76667, 110.8074, 37., 109.76667, 110.8074, 37., 73.17777, 73.871605, + 24.666666, 0., 0., 0., 0., 0., 0., + + 137.76666, 138.80742, 46.333332, 137.76666, 138.80742, 46.333332, 91.844444, 92.53828, + 30.88889, 0., 0., 0., 0., 0., 0.}; + + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 8d08455d814..819d05f9556 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -5155,6 +5155,11 @@ TEST_CASE(roialign_test) expect_shape(sout, migraphx::make_op("roialign"), sx, srois, sbi); + // data input must be 4 dimensions + migraphx::shape sx2{migraphx::shape::float_type, {2, 3, 4, 5, 6}}; + throws_shape(migraphx::make_op("roialign"), sx2, srois, sbi); + + // batch index must be 1 dimension migraphx::shape sbi1{migraphx::shape::int64_type, {2, 3}}; throws_shape(migraphx::make_op("roialign"), sx, srois, sbi1); @@ -5166,6 +5171,23 @@ TEST_CASE(roialign_test) migraphx::shape srois2{migraphx::shape::float_type, {2, 3}}; throws_shape(migraphx::make_op("roialign"), sx, srois2, sbi); + + // alternate data types + migraphx::shape sx_d{migraphx::shape::double_type, {3, 4, 5, 6}}; + migraphx::shape srois_d{migraphx::shape::double_type, {2, 4}}; + migraphx::shape sbi_int{migraphx::shape::int32_type, {2}}; + migraphx::shape sout_d{migraphx::shape::double_type, {2, 4, 1, 1}}; + expect_shape(sout_d, migraphx::make_op("roialign"), sx_d, srois_d, sbi_int); + + // wrong data types + migraphx::shape srois_int{migraphx::shape::int32_type, {2, 3}}; + throws_shape(migraphx::make_op("roialign"), sx, srois_int, sbi); + + migraphx::shape sx_int{migraphx::shape::int64_type, {3, 4, 5, 6}}; + throws_shape(migraphx::make_op("roialign"), sx_int, srois, sbi); + + migraphx::shape sbi_float{migraphx::shape::float_type, {2}}; + throws_shape(migraphx::make_op("roialign"), sx, srois, sbi_float); } TEST_CASE(test_concat) diff --git a/test/verify/test_roialign.cpp b/test/verify/test_roialign.cpp index 6314491e10d..88864631e87 100644 --- a/test/verify/test_roialign.cpp +++ b/test/verify/test_roialign.cpp @@ -27,6 +27,39 @@ #include #include +template +struct test_roialign_half_pixel : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_s{DType, {5, 7, 2, 2}}; + + migraphx::shape roi_s{DType, {2, 4}}; + + migraphx::shape ind_s{migraphx::shape::int64_type, {2}}; + std::vector ind_vec = {1, 0}; + + auto x = mm->add_parameter("x", x_s); + auto roi = mm->add_parameter("roi", roi_s); + auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); + auto r = mm->add_instruction( + migraphx::make_op("roialign", + {{"spatial_scale", 1.1}, + {"output_height", 5}, + {"output_width", 3}, + {"sampling_ratio", 3}, + {"coordinate_transformation_mode", "half_pixel"}}), + x, + roi, + ind); + mm->add_return({r}); + + return p; + } +}; + template struct test_roialign : verify_program> { @@ -44,20 +77,23 @@ struct test_roialign : verify_program> auto x = mm->add_parameter("x", x_s); auto roi = mm->add_parameter("roi", roi_s); auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); - auto r = mm->add_instruction(migraphx::make_op("roialign", - {{"spatial_scale", 1.0}, - {"output_height", 5}, - {"output_width", 5}, - {"sampling_ratio", 2}}), - x, - roi, - ind); + auto r = mm->add_instruction( + migraphx::make_op("roialign", + {{"spatial_scale", 1.1}, + {"output_height", 5}, + {"output_width", 2}, + {"sampling_ratio", 2}, + {"coordinate_transformation_mode", "output_half_pixel"}}), + x, + roi, + ind); mm->add_return({r}); return p; } }; +template struct test_roialign_half_pixel; template struct test_roialign; template struct test_roialign; template struct test_roialign;