Skip to content

Commit

Permalink
Add JIT pad (#1411) (#1441)
Browse files Browse the repository at this point in the history
updated GPU pad to now use JIT version.
added range functions for JIT kernels.

Co-authored-by: kahmed10 <[email protected]>
  • Loading branch information
causten and kahmed10 authored Nov 2, 2022
1 parent 360b180 commit 4d471bd
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 1 deletion.
100 changes: 100 additions & 0 deletions src/targets/gpu/jit/pad.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/float_equal.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

using namespace migraphx::gpu::gen; // NOLINT

static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/pad.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void pad_kernel(void* input_p, void* output_p)
{
auto offsets = index_ints<${offsets}>{};
auto idx = make_index();
make_tensors()(input_p, output_p)([&](auto input, auto output) {
pad(idx, offsets, input, output, ${pad_val});
});
}
}
} // namespace migraphx
)__migraphx__";

struct pad_compiler : compiler<pad_compiler>
{
std::vector<std::string> names() const { return {"pad"}; }

operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.kernel_name = "pad_kernel";
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));

auto pad_val = v.get("value", 0.f);
auto pad_val_string = to_string(pad_val);
if(float_equal(pad_val, std::numeric_limits<float>::lowest()))
pad_val_string = "lowest{}";
if(float_equal(pad_val, std::numeric_limits<float>::max()))
pad_val_string = "highest{}";

auto padding = v.at("pads").to_vector<int64_t>();
auto input_lens = inputs.front().lens();
std::vector<size_t> offsets(input_lens.size());
std::copy(padding.begin(), padding.begin() + offsets.size(), offsets.begin());

auto src = interpolate_string(
pointwise_kernel,
{{"pad_val", to_string(pad_val_string)}, {"offsets", to_string_range(offsets)}});
return compile_hip_code_object(src, options);
}

compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
63 changes: 63 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_PAD_HPP
#define MIGRAPHX_GUARD_KERNELS_PAD_HPP

#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>

namespace migraphx {

template <class Offsets, class Input, class Output, class PadVal>
__device__ void pad(const index& idx,
const Offsets& offsets,
const Input& input,
Output& output,
const PadVal& pad_val)
{
auto output_shape = output.get_shape();
idx.global_stride(output_shape.elements(), [&](auto i) {
// 1. get current multi-index for output
// 2. get the size of the input to determine input boundaries
// 3. compute the corresponding multi-index for input by accounting for offsets
// 4. if current multi-index is within offsets or input's new multi-index is out of bounds,
// use pad value instead of input's value
auto multi = output_shape.multi(i);
auto input_bounds = input.get_shape().lens;
auto input_idx = multi - offsets;
auto range_multi = range(multi.size());

if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
}))
output[multi] = pad_val;
else
output[multi] = input[input_idx];
});
}

} // namespace migraphx
#endif
49 changes: 49 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/ranges.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_RANGES_HPP
#define MIGRAPHX_GUARD_KERNELS_RANGES_HPP

#include <migraphx/kernels/iota_iterator.hpp>

namespace migraphx {

template <class Iterator>
struct iterator_range
{
Iterator start;
Iterator last;

constexpr Iterator begin() const { return start; }

constexpr Iterator end() const { return last; }
};

constexpr iterator_range<iota_iterator> range(diff_int start, diff_int last)
{
return {{start, {}}, {last, {}}};
}
constexpr iterator_range<iota_iterator> range(diff_int last) { return range(0, last); }

} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_RANGES_HPP
1 change: 0 additions & 1 deletion src/targets/gpu/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ struct miopen_apply
add_extend_op("lrn");
add_extend_op("multinomial");
add_extend_op("nonzero");
add_extend_op("pad");
add_extend_op("pooling");
add_extend_op("prefix_scan_sum");
add_extend_op("reverse");
Expand Down
15 changes: 15 additions & 0 deletions test/ref_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3884,6 +3884,21 @@ TEST_CASE(pad_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}

TEST_CASE(pad_test_asym)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{s, {1, 2, 3, 4}});
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 1, 1}}}), l0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(9);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 0, 3, 4, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(results_vector, gold));
}

TEST_CASE(pad_test_highest_half)
{
migraphx::program p;
Expand Down
42 changes: 42 additions & 0 deletions test/verify/test_pad_large.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct test_pad_large : verify_program<test_pad_large>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {586, 3, 224, 224}};
std::vector<int64_t> pads0 = {0, 0, 1, 1, 0, 0, 1, 1};
auto l0 = mm->add_parameter("x", s0);
mm->add_instruction(migraphx::make_op("pad", {{"pads", pads0}}), l0);
return p;
}
};

0 comments on commit 4d471bd

Please sign in to comment.