Merge branch 'develop' into bf16_verify_onnx
richagadgil authored Nov 25, 2024
2 parents 141bd27 + cd37826 commit 9d03a89
Showing 5 changed files with 90 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -106,7 +106,7 @@ include(ROCMSetupVersion)

 option(BUILD_DEV "Build for development purpose only" OFF)

-rocm_setup_version(VERSION 2.11.0)
+rocm_setup_version(VERSION 2.12.0)
 math(EXPR MIGRAPHX_SO_MAJOR_VERSION "(${PROJECT_VERSION_MAJOR} * 1000 * 1000) + (${PROJECT_VERSION_MINOR} * 1000) + ${PROJECT_VERSION_PATCH}")
 set(MIGRAPHX_SO_VERSION ${MIGRAPHX_SO_MAJOR_VERSION}.0)

23 changes: 20 additions & 3 deletions src/simplify_algebra.cpp
@@ -57,6 +57,22 @@ auto conv_const_weights()
         match::args(match::none_of(match::is_constant()), match::is_constant().bind("w")));
 }

+auto from_int4()
+{
+    return match::make_predicate_matcher([](instruction_ref start) {
+        return fix<bool>([&](auto self, instruction_ref ins) {
+            auto alias = instruction::get_output_alias(ins);
+            if(contains({"reshape", "dequantizelinear"}, alias->name()))
+                return self(alias->inputs().front());
+            if(alias->name() == "concat")
+                return all_of(alias->inputs(), self);
+            return alias->name() == "unpack_int4";
+        })(start);
+    });
+}
+
+auto not_from_int4() { return match::none_of(from_int4()); }
+
 auto reduction() { return match::name_contains("reduce"); }

// conv(x, w) * a => conv(x, a * w)
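The new from_int4 matcher uses MIGraphX's fix combinator to recurse through output aliases until it either reaches an unpack_int4 instruction or runs out of reshape/dequantizelinear/concat links. Below is a minimal standalone sketch of the fix pattern; it is an illustrative assumption, not the actual MIGraphX header.

// Standalone sketch of a fix combinator (assumed to mirror migraphx::fix;
// the real definition is not part of this diff). fix<R> hands a lambda a
// callable version of itself, enabling recursion through closures -- the
// same trick from_int4() uses to walk alias chains of arbitrary depth.
#include <iostream>
#include <utility>

template <class R, class F>
auto fix(F f)
{
    return [=](auto&&... xs) -> R { return f(fix<R>(f), std::forward<decltype(xs)>(xs)...); };
}

int main()
{
    // Toy recursion standing in for the alias walk: count steps down to zero.
    auto depth = fix<int>([](auto self, int n) -> int {
        if(n <= 0)
            return 0;
        return 1 + self(n - 1);
    });
    std::cout << depth(5) << "\n"; // prints 5
}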
@@ -208,8 +224,8 @@ struct find_mul_dot
 {
     auto matcher() const
     {
-        auto is_dot_const_inputs =
-            match::name("dot")(match::any_of[match::inputs()](match::is_constant()));
+        auto constant = match::is_constant(not_from_int4());
+        auto is_dot_const_inputs = match::name("dot")(match::any_of[match::inputs()](constant));
         return match::name("mul")(match::either_arg(0, 1)(
             is_dot_const_inputs.bind("dot"), match::name("broadcast", "multibroadcast").bind("c")));
     }
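find_mul_dot rewrites mul(dot(x, W), broadcast(c)) into dot(x, W * c) when W is constant, letting the scale be constant-folded into the weights. Excluding constants that originate from unpack_int4 presumably keeps such weights packed instead of materializing a fully dequantized copy. A self-contained check of the underlying identity follows (not MIGraphX code; the scales are powers of two so double arithmetic is exact):

#include <array>
#include <cassert>

int main()
{
    // Scaling the columns of dot(A, W) by c equals scaling W's columns first.
    std::array<std::array<double, 2>, 2> A{{{1, 2}, {3, 4}}};
    std::array<std::array<double, 2>, 2> W{{{5, 6}, {7, 8}}};
    std::array<double, 2> c{0.5, 2.0};
    for(int i = 0; i < 2; i++)
    {
        for(int j = 0; j < 2; j++)
        {
            double dot_then_mul = (A[i][0] * W[0][j] + A[i][1] * W[1][j]) * c[j];
            double fold_into_w  = A[i][0] * (W[0][j] * c[j]) + A[i][1] * (W[1][j] * c[j]);
            assert(dot_then_mul == fold_into_w);
        }
    }
}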
@@ -358,7 +374,8 @@ struct find_dot_mul
             match::used_once(),
             match::either_arg(0, 1)(const_broadcast.bind("d"),
                                     match::none_of(match::is_constant()).bind("z")));
-        return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c")));
+        return match::name("dot")(
+            match::either_arg(0, 1)(mul, match::is_constant(not_from_int4()).bind("c")));
     }

     void apply(module& m, const match::matcher_result& r) const
2 changes: 2 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
@@ -131,6 +131,8 @@ bool mlir_attention_enabled(context* ctx)
 #ifdef MIGRAPHX_MLIR
     if(not mlir_enabled())
         return false;
+    if(specific_op<rejected>("attention"))
+        return false;
     // Enable attention by default for mi300
     if(ctx != nullptr and starts_with(ctx->get_current_device().get_gfx_name(), "gfx94"))
         return true;
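This makes MLIR attention offload opt-out-able at runtime. Judging from the specific_op helper defined elsewhere in fuse_mlir.cpp (not shown in this diff), an op appears to be rejected when the MIGRAPHX_MLIR_USE_SPECIFIC_OPS environment variable lists it with a leading '!', e.g. MIGRAPHX_MLIR_USE_SPECIFIC_OPS=!attention. A hedged standalone sketch of that check:

// Sketch only -- assumed behavior of specific_op<rejected>; the real helper
// lives elsewhere in fuse_mlir.cpp and may differ in details.
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>

static bool op_rejected(const std::string& op)
{
    const char* env = std::getenv("MIGRAPHX_MLIR_USE_SPECIFIC_OPS");
    if(env == nullptr)
        return false;
    std::stringstream ss(env);
    std::string item;
    while(std::getline(ss, item, ','))
    {
        if(item == "!" + op) // e.g. MIGRAPHX_MLIR_USE_SPECIFIC_OPS=!attention
            return true;
    }
    return false;
}

int main() { std::cout << std::boolalpha << op_rejected("attention") << "\n"; }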
2 changes: 1 addition & 1 deletion test/onnx/.onnxrt-commit
@@ -1 +1 @@
-ac9c135b9543ad0374fe335bc3dc5feb0f24f010
+b1ccbe2a8efed30b749207b1a29ae03c50289040
66 changes: 66 additions & 0 deletions test/simplify_algebra_test.cpp
@@ -4024,6 +4024,72 @@ TEST_CASE(mul_dot_b_not_k_broadcast)
     EXPECT(m1.sort() == m2.sort());
 }

+TEST_CASE(mul_dot_a_int4_dq)
+{
+    migraphx::shape as{migraphx::shape::float_type, {1, 32, 4096}};
+    migraphx::shape bs{migraphx::shape::int8_type, {22016, 2048}};
+    migraphx::shape cs{migraphx::shape::float_type, {22016, 4096}};
+    migraphx::module m1;
+    {
+        auto a = m1.add_parameter("input", as);
+
+        auto lit =
+            m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4096}}));
+        auto litb = m1.add_instruction(
+            migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", as.lens()}}), lit);
+        auto mul = m1.add_instruction(migraphx::make_op("mul"), a, litb);
+
+        auto b = m1.add_literal(migraphx::generate_literal(bs));
+        auto unpack = m1.add_instruction(migraphx::make_op("unpack_int4"), b);
+        auto scales = m1.add_literal(migraphx::generate_literal(cs));
+        auto dq = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack, scales);
+        auto unsqueeze = m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), dq);
+        auto transpose = m1.add_instruction(
+            migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze);
+        auto dot = m1.add_instruction(migraphx::make_op("dot"), mul, transpose);
+        m1.add_return({dot});
+    };
+    migraphx::module m2 = m1;
+    run_pass(m1);
+
+    EXPECT(m1.sort() == m2.sort());
+}
+
+TEST_CASE(mul_dot_a_int4_dq_concat)
+{
+    migraphx::shape as{migraphx::shape::float_type, {1, 32, 4096}};
+    migraphx::shape bs{migraphx::shape::int8_type, {4096, 5504}};
+    migraphx::shape cs{migraphx::shape::float_type, {4096, 11008}};
+    migraphx::module m1;
+    {
+        auto a = m1.add_parameter("input", as);
+
+        auto lit =
+            m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4096}}));
+        auto litb = m1.add_instruction(
+            migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", as.lens()}}), lit);
+        auto mul = m1.add_instruction(migraphx::make_op("mul"), a, litb);
+
+        std::vector<migraphx::instruction_ref> concats;
+        for(int i = 0; i < 2; i++)
+        {
+            auto b = m1.add_literal(migraphx::generate_literal(bs));
+            auto unpack = m1.add_instruction(migraphx::make_op("unpack_int4"), b);
+            auto scales = m1.add_literal(migraphx::generate_literal(cs));
+            auto dq = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack, scales);
+            concats.push_back(
+                m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), dq));
+        }
+        auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), concats);
+        auto dot = m1.add_instruction(migraphx::make_op("dot"), mul, concat);
+        m1.add_return({dot});
+    };
+    migraphx::module m2 = m1;
+    run_pass(m1);
+
+    EXPECT(m1.sort() == m2.sort());
+}
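Both new tests expect simplify_algebra to leave the module untouched (m1.sort() == m2.sort()), since the dot's constant operand traces back to unpack_int4. The run_pass helper they call is defined near the top of this test file and is not part of the diff; the sketch below shows the usual pattern in MIGraphX pass tests (an assumption, not the verbatim definition):

#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_algebra.hpp>

static void run_pass(migraphx::module& m)
{
    // Run the pass under test, then remove any instructions it orphaned.
    migraphx::run_passes(m, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});
}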

 TEST_CASE(dot_mul_a)
 {
     migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
