Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into gqa-jit
Browse files Browse the repository at this point in the history
  • Loading branch information
turneram committed Oct 17, 2024
2 parents 29adc77 + 275f854 commit 3a3707c
Show file tree
Hide file tree
Showing 13 changed files with 364 additions and 66 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/[email protected] -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@424059e63114ba32827a9f022233fc15ed8e3378 -DBUILD_FAT_LIBROCKCOMPILER=On
ROCm/rocMLIR@82378ac44b8627ecbaf9078353fb53588090fe28 -DBUILD_FAT_LIBROCKCOMPILER=On
81 changes: 47 additions & 34 deletions src/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,49 +151,57 @@ shape common_shape(const std::vector<shape>& shapes)
return {compute_common_types(shapes), compute_common_lens(shapes)};
}

std::vector<instruction_ref>
insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> inputs)
std::vector<instruction_ref> insert_common_args(module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
common_options options)
{
if(std::any_of(
inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
{
auto input_shapes = to_shapes(inputs);
auto c_type = compute_common_types(input_shapes);
auto c_dyn_dims = compute_common_dyn_dims(input_shapes);
if(options.common_lens)
{
auto c_dyn_dims = compute_common_dyn_dims(input_shapes);

auto s0 = inputs[0]->get_shape();
// always add both multibroadcast instructions for dynamic shapes
inputs[0] = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
auto s = input->get_shape();
return m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
input,
inputs[0]);
});
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type)
{
input =
m.insert_instruction(ins, make_op("convert", {{"target_type", c_type}}), input);
}
return input;
});
auto s0 = inputs[0]->get_shape();
// always add both multibroadcast instructions for dynamic shapes
inputs[0] = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
auto s = input->get_shape();
return m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
input,
inputs[0]);
});
}
if(options.common_type)
{
auto c_type = compute_common_types(input_shapes);
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type)
{
input = m.insert_instruction(
ins, make_op("convert", {{"target_type", c_type}}), input);
}
return input;
});
}
}
else
{
auto common = common_shape(to_shapes(inputs));
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().lens() != common.lens())
if(options.common_lens and input->get_shape().lens() != common.lens())
{
input = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
}
if(input->get_shape().type() != common.type())
if(options.common_type and input->get_shape().type() != common.type())
{
input = m.insert_instruction(
ins, make_op("convert", {{"target_type", common.type()}}), input);
Expand All @@ -204,22 +212,27 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
return inputs;
}

std::vector<instruction_ref> add_common_args(module& m, std::vector<instruction_ref> inputs)
std::vector<instruction_ref>
add_common_args(module& m, std::vector<instruction_ref> inputs, common_options options)
{
return insert_common_args(m, m.end(), std::move(inputs));
return insert_common_args(m, m.end(), std::move(inputs), options);
}

instruction_ref insert_common_op(module& m,
instruction_ref ins,
const operation& op,
std::vector<instruction_ref> inputs)
std::vector<instruction_ref> inputs,
common_options options)
{
return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs)));
return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs), options));
}

instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs)
instruction_ref add_common_op(module& m,
const operation& op,
std::vector<instruction_ref> inputs,
common_options options)
{
return insert_common_op(m, m.end(), op, std::move(inputs));
return insert_common_op(m, m.end(), op, std::move(inputs), options);
}

shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens)
Expand Down
1 change: 1 addition & 0 deletions src/driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ add_executable(driver
main.cpp
verify.cpp
passes.cpp
mlir.cpp
models.cpp
perf.cpp
marker_roctx.cpp
Expand Down
5 changes: 5 additions & 0 deletions src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "verify_options.hpp"
#include "argument_parser.hpp"
#include "command.hpp"
#include "mlir.hpp"
#include "precision.hpp"
#include "passes.hpp"
#include "perf.hpp"
Expand Down Expand Up @@ -77,6 +78,7 @@ struct loader
bool is_test = false;
unsigned trim = 0;
bool optimize = false;
bool mlir = false;
bool skip_unknown_operators = false;
bool brief = false;
std::string output_type;
Expand Down Expand Up @@ -140,6 +142,7 @@ struct loader
ap.append(),
ap.nargs(2));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
ap(mlir, {"--mlir"}, ap.help("Offload everything to mlir"), ap.set_value(true));
ap(passes, {"--apply-pass", "-p"}, ap.help("Passes to apply to model"), ap.append());
ap(output_type,
{"--graphviz", "-g"},
Expand Down Expand Up @@ -374,6 +377,8 @@ struct loader
}
if(not passes.empty())
migraphx::run_passes(p, get_passes(passes));
if(mlir)
offload_to_mlir(p);
return p;
}

Expand Down
59 changes: 59 additions & 0 deletions src/driver/mlir.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "mlir.hpp"
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/param_utils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <unordered_map>

namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

void offload_to_mlir(program& p)
{
auto* mm = p.get_main_module();
auto* mlirm = p.create_module("mlir");
mlirm->set_bypass();
std::vector<instruction_ref> inputs;
copy_if(iterator_for(*mm), std::back_inserter(inputs), [&](instruction_ref ins) {
if(ins->name() == "@param")
return true;
if(ins->name() == "@literal")
return ins->get_shape().elements() != 1;
return false;
});

std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::size_t n = 0;
for(auto ins : inputs)
{
map_ins[ins] = mlirm->add_parameter(param_name(n++), ins->get_shape().as_standard());
}

auto mlir_last = mlirm->add_instructions(mm, &map_ins);
mlirm->add_return(mlir_last);

auto last = std::prev(mm->end());
auto mlir_op = mm->insert_instruction(last, make_op("gpu::mlir_op"), inputs, {mlirm});
if(mlir_last.size() > 1)
{
std::vector<instruction_ref> outputs;
transform(range(mlir_last.size()), std::back_inserter(outputs), [&](auto i) {
return mm->insert_instruction(last, make_op("get_tuple_elem", {{"index", i}}), mlir_op);
});
mm->replace_return(outputs);
}
else
{
mm->replace_return({mlir_op});
}
run_passes(*mm, {dead_code_elimination{}});
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
16 changes: 16 additions & 0 deletions src/driver/mlir.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef MIGRAPHX_GUARD_DRIVER_MLIR_HPP
#define MIGRAPHX_GUARD_DRIVER_MLIR_HPP

#include <migraphx/config.hpp>
#include <migraphx/program.hpp>

namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

void offload_to_mlir(program& p);

} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DRIVER_MLIR_HPP
23 changes: 18 additions & 5 deletions src/include/migraphx/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct operation;

struct common_options
{
bool common_type = true;
bool common_lens = true;
};

/**
* Broadcasting works by comparing the shapes element-wise starting with
* the trailing (right-most) dimensions and working leftwards. This is equivalent
Expand Down Expand Up @@ -112,23 +118,30 @@ std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<
* attached to each instruction_ref are considered for broadcasting
* @return std::vector<instruction_ref> a modified argument list
*/
MIGRAPHX_EXPORT std::vector<instruction_ref>
insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> inputs);
MIGRAPHX_EXPORT std::vector<instruction_ref> insert_common_args(module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
common_options options = {});

MIGRAPHX_EXPORT
std::vector<instruction_ref> add_common_args(module& m, std::vector<instruction_ref> inputs);
std::vector<instruction_ref>
add_common_args(module& m, std::vector<instruction_ref> inputs, common_options options = {});

MIGRAPHX_EXPORT
instruction_ref insert_common_op(module& m,
instruction_ref ins,
const operation& op,
std::vector<instruction_ref> inputs);
std::vector<instruction_ref> inputs,
common_options options = {});

/**
* @brief Wrapper for insert_common_args() which inserts operation at the end of the module.
*/
MIGRAPHX_EXPORT
instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs);
instruction_ref add_common_op(module& m,
const operation& op,
std::vector<instruction_ref> inputs,
common_options options = {});

/**
* Calculates the broadcasted shape with the given input_shape and broadcasted dimensions.
Expand Down
13 changes: 11 additions & 2 deletions src/instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <migraphx/erase.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <deque>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -62,10 +63,18 @@ void instruction::replace(const shape& r)
if(r != result)
{
result = r;
for(auto&& ins : output)
std::deque<instruction_ref> q(output.begin(), output.end());
while(not q.empty())
{
instruction_ref ins = q.front();
q.pop_front();
assert(ins->name() == "@return" or ins->name().front() != '@');
ins->recompute_shape();
shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
if(new_r != ins->result)
{
ins->result = new_r;
std::copy(ins->output.begin(), ins->output.end(), std::back_inserter(q));
}
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion src/targets/gpu/compile_hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,15 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
bool hip_has_flags(const std::vector<std::string>& flags)
{
hiprtc_program prog{" "};

std::string src = " ";
src_file input{"main.cpp", src};
std::vector<src_file> srcs = {input};

try
{
prog.compile(flags, true);
std::string arch = "gfx900";
compile_hip_src(srcs, flags, arch);
return true;
}
catch(...)
Expand Down
Loading

0 comments on commit 3a3707c

Please sign in to comment.