diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 9c9e20b78a..b2d2804003 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -125,11 +125,14 @@ static bool specific_op(std::string_view option, bool fallback = false) return contains(options, option); } -bool mlir_attention_enabled() +bool mlir_attention_enabled(context* ctx) { #ifdef MIGRAPHX_MLIR if(not mlir_enabled()) return false; + // Enable attention by default for mi300 (gfx names starting with "gfx94"); returns before specific_op("attention") is consulted + if(ctx != nullptr and starts_with(ctx->get_current_device().get_gfx_name(), "gfx94")) + return true; return specific_op("attention"); #else return false; @@ -996,7 +999,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const }; // Attention offloads; default disabled - if(mlir_attention_enabled() or enable_extra) + if(mlir_attention_enabled(ctx) or enable_extra) { match::find_matches(mpm, find_mlir_attention_fused_ops{mlir_mode::all}); mpm.run_pass(dead_code_elimination{}); diff --git a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp index c7ee06a564..6a7966db89 100644 --- a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp @@ -34,7 +34,6 @@ struct module_pass_manager; namespace gpu { MIGRAPHX_GPU_EXPORT bool mlir_enabled(); -MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled(); struct MIGRAPHX_GPU_EXPORT fuse_mlir {