Skip to content

Commit

Permalink
Add compute_fp32 flag for quant_gemm tests (#1360) (#1409)
Browse files Browse the repository at this point in the history
test_gpu_pack_int8_args fails on gfx908 machine, because it doesn't set compute_fp32 flag correctly. This PR fixes the test such that it checks for the device-name, and rocblas-versions and sets this flag accordingly.
  • Loading branch information
causten authored Oct 6, 2022
1 parent 0092375 commit dea0a80
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 62 deletions.
6 changes: 5 additions & 1 deletion src/targets/gpu/include/migraphx/gpu/rocblas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP

#include <migraphx/manage_ptr.hpp>
#include <migraphx/config.hpp>
#include <rocblas.h>
Expand All @@ -37,6 +36,11 @@ using rocblas_handle_ptr = MIGRAPHX_MANAGE_PTR(rocblas_handle, rocblas_destroy_h
rocblas_handle_ptr create_rocblas_handle_ptr();
rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s);

struct context;

bool get_compute_fp32_flag();

bool get_int8_x4_format(context& ctx);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Expand Down
18 changes: 3 additions & 15 deletions src/targets/gpu/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,26 +99,14 @@ struct miopen_apply
(void)i;
}

const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}

void init()
{
assert(mod != nullptr);
assert(pass != nullptr);

#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto& ctx = get_context();
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
auto& ctx = get_context();
int8_x4_format = get_int8_x4_format(ctx);
compute_fp32 = get_compute_fp32_flag();

offload_copy = (mod->name() == "main") ? pass->offload_copy : false;

Expand Down
33 changes: 33 additions & 0 deletions src/targets/gpu/rocblas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <unordered_set>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand All @@ -41,6 +47,33 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
return rb;
}

const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}

bool get_compute_fp32_flag()
{
bool compute_fp32 = false;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
#endif
return compute_fp32;
}

bool get_int8_x4_format(context& ctx)
{
bool int8_x4_format = true;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
return int8_x4_format;
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
95 changes: 49 additions & 46 deletions test/gpu/pack_int8_args.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <migraphx/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/replace_allocate.hpp>
Expand All @@ -39,9 +40,8 @@
#include <migraphx/make_op.hpp>
#include <test.hpp>

void run_passes(migraphx::module& m)
void run_passes(migraphx::module& m, migraphx::gpu::context& ctx)
{
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(m,
{migraphx::auto_contiguous{},
migraphx::gpu::lowering{&ctx, false},
Expand All @@ -52,18 +52,6 @@ void run_passes(migraphx::module& m)
migraphx::dead_code_elimination{}});
}

bool get_int8_x4_format()
{
bool int8_x4_format = true;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto ctx = migraphx::gpu::context{};
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
return int8_x4_format;
}

TEST_CASE(quant_dot)
{
auto create_module = [] {
Expand Down Expand Up @@ -102,11 +90,13 @@ TEST_CASE(quant_dot)
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m2_shape)}}));
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), l2, alloc);
}
auto gemm =
m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}),
l1,
packa,
gemm_alloc);
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"int8_x4_format", int8_x4},
{"compute_fp32", migraphx::gpu::get_compute_fp32_flag()}}),
l1,
packa,
gemm_alloc);

auto beta_broadcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta);
Expand All @@ -124,11 +114,12 @@ TEST_CASE(quant_dot)
return m;
};

auto m1 = create_module();
run_passes(m1);
auto m1 = create_module();
auto ctx = migraphx::gpu::context{};
run_passes(m1, ctx);

bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
bool int8_x4 = migraphx::gpu::get_int8_x4_format(ctx);
auto m2 = create_optimized_int8_x4(int8_x4);
EXPECT(m1 == m2);
}

Expand Down Expand Up @@ -211,21 +202,24 @@ TEST_CASE(quant_dot_trans)
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb);
}

auto gemm =
m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}),
tl1_alpha_int8,
packb,
output);
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"int8_x4_format", int8_x4},
{"compute_fp32", migraphx::gpu::get_compute_fp32_flag()}}),
tl1_alpha_int8,
packb,
output);
m.add_return({gemm});

return m;
};

auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
auto m1 = create_module();
auto ctx = migraphx::gpu::context{};
run_passes(m1, ctx);

run_passes(m1);
bool int8_x4 = migraphx::gpu::get_int8_x4_format(ctx);
auto m2 = create_optimized_int8_x4(int8_x4);

EXPECT(m1 == m2);
}
Expand Down Expand Up @@ -292,11 +286,13 @@ TEST_CASE(quant_dot_pad)
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pl2, alloc);
}

auto gemm =
m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}),
pl1,
packa,
gemm_alloc);
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"int8_x4_format", int8_x4},
{"compute_fp32", migraphx::gpu::get_compute_fp32_flag()}}),
pl1,
packa,
gemm_alloc);

auto beta_broadcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta);
Expand All @@ -313,11 +309,12 @@ TEST_CASE(quant_dot_pad)
return m;
};

auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
auto m1 = create_module();
auto ctx = migraphx::gpu::context{};
run_passes(m1, ctx);

run_passes(m1);
bool int8_x4 = migraphx::gpu::get_int8_x4_format(ctx);
auto m2 = create_optimized_int8_x4(int8_x4);

EXPECT(m1 == m2);
}
Expand Down Expand Up @@ -438,17 +435,23 @@ TEST_CASE(quant_dot_trans_pad)
}

auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), pa, packb, output);
migraphx::make_op("gpu::quant_gemm",
{{"int8_x4_format", int8_x4},
{"compute_fp32", migraphx::gpu::get_compute_fp32_flag()}}),
pa,
packb,
output);
m.add_return({gemm});

return m;
};

auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
auto m1 = create_module();
auto ctx = migraphx::gpu::context{};
run_passes(m1, ctx);

run_passes(m1);
bool int8_x4 = migraphx::gpu::get_int8_x4_format(ctx);
auto m2 = create_optimized_int8_x4(int8_x4);

EXPECT(m1 == m2);
}
Expand Down

0 comments on commit dea0a80

Please sign in to comment.