Skip to content

Commit

Permalink
[tuner]: add c/python binding for querying mma intrinsic (#19218)
Browse files Browse the repository at this point in the history
After this PR: #19199

add Python bindings to these two utility functions to querying mma
intrinsic instructions from input module.

---------

Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu authored Nov 20, 2024
1 parent 1654ce6 commit e1ce3fa
Show file tree
Hide file tree
Showing 10 changed files with 114 additions and 2 deletions.
8 changes: 8 additions & 0 deletions compiler/bindings/c/iree/compiler/dialects/iree_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenCompilationInfoAttrGet(
MLIR_CAPI_EXPORTED ireeCodegenCompilationInfoParameters
ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr);

MLIR_CAPI_EXPORTED void
ireeCodegenGetExecutableVariantOps(MlirModule module, size_t *numOps,
MlirOperation *executableOps);

MLIR_CAPI_EXPORTED void ireeCodegenQueryMMAIntrinsics(MlirOperation op,
size_t *numIntrinsics,
uint32_t *mmaIntrinsics);

#ifdef __cplusplus
}
#endif
Expand Down
45 changes: 45 additions & 0 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,33 @@ static const char *kGpuModuleImportPath =
namespace py = pybind11;
using namespace mlir::python::adaptors;

static std::vector<MlirOperation>
ireeCodegenGetExecutableVariantOpsBinding(MlirModule module) {
size_t numOps = 0;
ireeCodegenGetExecutableVariantOps(module, &numOps, nullptr);
std::vector<MlirOperation> ops(numOps);
ireeCodegenGetExecutableVariantOps(module, &numOps, ops.data());

return ops;
}

static std::vector<py::object>
ireeCodegenQueryMMAIntrinsicsBinding(MlirOperation op) {
size_t numMMAs = 0;
ireeCodegenQueryMMAIntrinsics(op, &numMMAs, nullptr);
std::vector<uint32_t> mmaIntrinsics(numMMAs);
ireeCodegenQueryMMAIntrinsics(op, &numMMAs, mmaIntrinsics.data());

py::object mmaIntrinsicEnum =
py::module_::import(kGpuModuleImportPath).attr("MMAIntrinsic");
std::vector<py::object> mmaList(numMMAs);
for (size_t i = 0; i < numMMAs; ++i) {
mmaList[i] = mmaIntrinsicEnum(mmaIntrinsics[i]);
}

return mmaList;
}

PYBIND11_MODULE(_ireeCompilerDialects, m) {
m.doc() = "iree-compiler dialects python extension";

Expand Down Expand Up @@ -326,4 +353,22 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
"Gets an #iree_gpu.lowering_config from parameters.")
.def_property_readonly("attributes",
ireeGPULoweringConfigAttrGetAttributes);

//===-------------------------------------------------------------------===//
// Binding to utility function getExecutableVariantOps
//===-------------------------------------------------------------------===//

iree_codegen_module.def(
"get_executable_variant_ops", &ireeCodegenGetExecutableVariantOpsBinding,
"Gets the executable variant operations from a module.",
py::arg("module"));

//===-------------------------------------------------------------------===//
// Binding to utility function queryMMAIntrinsics
//===-------------------------------------------------------------------===//

iree_codegen_module.def(
"query_mma_intrinsics", &ireeCodegenQueryMMAIntrinsicsBinding,
"Queries the MMA intrinsics from an executable variant op.",
py::arg("op"));
}
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ iree_compiler_cc_library(
deps = [
"//compiler/bindings/c:headers",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Utils",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:IR",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ iree_cc_library(
MLIRCAPIIR
MLIRIR
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Utils
iree::compiler::bindings::c::headers
PUBLIC
)
Expand Down
49 changes: 49 additions & 0 deletions compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/dialects/iree_codegen.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
Expand All @@ -24,6 +25,8 @@ using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline;
using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr;
using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttrInterface;
using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr;
using mlir::iree_compiler::IREE::GPU::MMAIntrinsic;
using mlir::iree_compiler::IREE::HAL::ExecutableVariantOp;

bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr(
MlirAttribute attr) {
Expand Down Expand Up @@ -149,3 +152,49 @@ ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr) {
parameters.translationInfo = wrap(compilationInfo.getTranslationInfo());
return parameters;
}

void ireeCodegenGetExecutableVariantOps(MlirModule module, size_t *numOps,
MlirOperation *executableOps) {
assert(!mlirModuleIsNull(module) && "module cannot be nullptr");
assert(numOps && "numOps cannot be nullptr");

mlir::ModuleOp moduleOp = unwrap(module);
llvm::SmallVector<ExecutableVariantOp> executableVariantOps =
mlir::iree_compiler::getExecutableVariantOps(moduleOp);

if (!executableOps) {
*numOps = executableVariantOps.size();
return;
}

assert(
*numOps == executableVariantOps.size() &&
"*numOps must match the number of elements in the executableVariantOps");

for (size_t i = 0, e = executableVariantOps.size(); i < e; ++i) {
executableOps[i] = wrap(executableVariantOps[i]);
}
}

void ireeCodegenQueryMMAIntrinsics(MlirOperation op, size_t *numIntrinsics,
uint32_t *mmaIntrinsics) {
assert(numIntrinsics && "numIntrinsics cannot be nullptr");

mlir::Operation *mlirOp = unwrap(op);
auto variantOp = llvm::dyn_cast_if_present<ExecutableVariantOp>(mlirOp);
assert(variantOp && "operation is not a ExecutableVariantOp");

llvm::SmallVector<MMAIntrinsic> intrinsics =
mlir::iree_compiler::queryMMAIntrinsics(variantOp);
if (!mmaIntrinsics) {
*numIntrinsics = intrinsics.size();
return;
}

assert(*numIntrinsics == intrinsics.size() &&
"*numIntrinsics must match the number of elements in the intrinsics");

for (size_t i = 0, e = intrinsics.size(); i < e; ++i) {
mmaIntrinsics[i] = static_cast<uint32_t>(intrinsics[i]);
}
}
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ extern void ireeCodegenCompilationInfoAttrGetTypeID();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGet();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID();
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue();
extern void ireeCodegenGetExecutableVariantOps();
extern void ireeCodegenQueryMMAIntrinsics();
extern void ireeCodegenTranslationInfoAttrGet();
extern void ireeCodegenTranslationInfoAttrGetParameters();
extern void ireeCodegenTranslationInfoAttrGetTypeID();
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.def
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ EXPORTS
ireeCodegenDispatchLoweringPassPipelineAttrGet
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
ireeCodegenDispatchLoweringPassPipelineAttrGetValue
ireeCodegenGetExecutableVariantOps
ireeCodegenQueryMMAIntrinsics
ireeCodegenTranslationInfoAttrGet
ireeCodegenTranslationInfoAttrGetParameters
ireeCodegenTranslationInfoAttrGetTypeID
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.ld
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ VER_0 {
ireeCodegenDispatchLoweringPassPipelineAttrGet;
ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID;
ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
ireeCodegenGetExecutableVariantOps;
ireeCodegenQueryMMAIntrinsics;
ireeCodegenTranslationInfoAttrGet;
ireeCodegenTranslationInfoAttrGetParameters;
ireeCodegenTranslationInfoAttrGetTypeID;
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.macos.lst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ _ireeCodegenCompilationInfoAttrGetTypeID
_ireeCodegenDispatchLoweringPassPipelineAttrGet
_ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
_ireeCodegenDispatchLoweringPassPipelineAttrGetValue
_ireeCodegenGetExecutableVariantOps
_ireeCodegenQueryMMAIntrinsics
_ireeCodegenTranslationInfoAttrGet
_ireeCodegenTranslationInfoAttrGetParameters
_ireeCodegenTranslationInfoAttrGetTypeID
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func) {

SmallVector<IREE::HAL::ExecutableVariantOp>
getExecutableVariantOps(mlir::ModuleOp moduleOp) {
llvm::SmallVector<IREE::HAL::ExecutableVariantOp> executableVariantOps;
SmallVector<IREE::HAL::ExecutableVariantOp> executableVariantOps;
moduleOp.walk([&](IREE::HAL::ExecutableVariantOp executableOp) {
executableVariantOps.push_back(executableOp);
});
Expand All @@ -1039,7 +1039,7 @@ getExecutableVariantOps(mlir::ModuleOp moduleOp) {

SmallVector<IREE::GPU::MMAIntrinsic>
queryMMAIntrinsics(IREE::HAL::ExecutableVariantOp executableOp) {
llvm::SmallVector<IREE::GPU::MMAIntrinsic> mmaIntrinsics;
SmallVector<IREE::GPU::MMAIntrinsic> mmaIntrinsics;
if (IREE::GPU::TargetAttr target = getGPUTargetAttr(executableOp)) {
mmaIntrinsics = llvm::map_to_vector(
target.getWgp().getMma(),
Expand Down

0 comments on commit e1ce3fa

Please sign in to comment.