Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Remove unused MLIR components #2443

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
else:
self.ctx = Context()
register_all_dialects(self.ctx)
quake.register_dialect(self.ctx)
cc.register_dialect(self.ctx)
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
self.loc = Location.unknown(context=self.ctx)
self.module = Module.create(loc=self.loc)
Expand Down
5 changes: 2 additions & 3 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@

# We need static initializers to run in the CAPI `ExecutionEngine`,
# so here we run a simple JIT compile at global scope
with Context():
with Context() as ctx:
register_all_dialects(ctx)
module = Module.parse(r"""
llvm.func @none() {
llvm.return
Expand Down Expand Up @@ -246,8 +247,6 @@ class PyKernel(object):
def __init__(self, argTypeList):
self.ctx = Context()
register_all_dialects(self.ctx)
quake.register_dialect(self.ctx)
cc.register_dialect(self.ctx)
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)

self.metadata = {'conditionalOnMeasure': False}
Expand Down
9 changes: 9 additions & 0 deletions python/cudaq/mlir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# ============================================================================ #
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

from ._mlir_libs._quakeDialects import register_all_dialects
22 changes: 15 additions & 7 deletions python/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI
RELATIVE_INSTALL_ROOT "../.."
DECLARED_SOURCES
CUDAQuantumPythonSources
# TODO: Remove this in favor of showing fine grained registration once
# available.
MLIRPythonExtension.RegisterEverything
MLIRPythonSources.Core
MLIRPythonSources.Dialects.arith
MLIRPythonSources.Dialects.builtin
MLIRPythonSources.Dialects.cf
MLIRPythonSources.Dialects.complex
MLIRPythonSources.Dialects.func
MLIRPythonSources.Dialects.math
MLIRPythonSources.ExecutionEngine
)

################################################################################
Expand All @@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules
INSTALL_PREFIX "cudaq/mlir"
DECLARED_SOURCES
CUDAQuantumPythonSources
# TODO: Remove this in favor of showing fine grained registration once
# available.
MLIRPythonExtension.RegisterEverything
MLIRPythonSources
MLIRPythonSources.Core
MLIRPythonSources.Dialects.arith
MLIRPythonSources.Dialects.builtin
MLIRPythonSources.Dialects.cf
MLIRPythonSources.Dialects.complex
MLIRPythonSources.Dialects.func
MLIRPythonSources.Dialects.math
MLIRPythonSources.ExecutionEngine
COMMON_CAPI_LINK_LIBS
CUDAQuantumPythonCAPI
)
Expand Down
65 changes: 13 additions & 52 deletions python/runtime/mlir/py_register_dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,13 @@
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "mlir/Bindings/Python/PybindAdaptors.h"

#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/CAPI/Dialects.h"
#include "cudaq/Optimizer/CodeGen/Passes.h"
#include "cudaq/Optimizer/CodeGen/Pipelines.h"
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/InitAllDialects.h"
#include "cudaq/Optimizer/InitAllDialects.h"
#include "cudaq/Optimizer/InitAllPasses.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include <fmt/core.h>
#include <pybind11/complex.h>
#include <pybind11/stl.h>
Expand All @@ -28,32 +22,10 @@ using namespace mlir::python::adaptors;
using namespace mlir;

namespace cudaq {
static bool registered = false;

void registerQuakeDialectAndTypes(py::module &m) {
void registerQuakeTypes(py::module &m) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void registerQuakeTypes(py::module &m) {
static void registerQuakeTypes(py::module &m) {

auto quakeMod = m.def_submodule("quake");

quakeMod.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle handle = mlirGetDialectHandle__quake__();
mlirDialectHandleRegisterDialect(handle, context);
if (load) {
mlirDialectHandleLoadDialect(handle, context);
}

if (!registered) {
cudaq::opt::registerOptCodeGenPasses();
cudaq::opt::registerOptTransformsPasses();
cudaq::opt::registerAggressiveEarlyInlining();
cudaq::opt::registerUnrollingPipeline();
cudaq::opt::registerTargetPipelines();
cudaq::opt::registerMappingPipeline();
registered = true;
}
},
py::arg("context") = py::none(), py::arg("load") = true);

mlir_type_subclass(quakeMod, "RefType", [](MlirType type) {
return unwrap(type).isa<quake::RefType>();
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
Expand Down Expand Up @@ -143,21 +115,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
});
}

void registerCCDialectAndTypes(py::module &m) {
void registerCCTypes(py::module &m) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void registerCCTypes(py::module &m) {
static void registerCCTypes(py::module &m) {

Assuming that's ok, should move these 2 out of the namespace as well.


auto ccMod = m.def_submodule("cc");

ccMod.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__();
mlirDialectHandleRegisterDialect(ccHandle, context);
if (load) {
mlirDialectHandleLoadDialect(ccHandle, context);
}
},
py::arg("context") = py::none(), py::arg("load") = true);

mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) {
return unwrap(type).isa<cudaq::cc::CharspanType>();
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
Expand Down Expand Up @@ -298,9 +259,6 @@ void registerCCDialectAndTypes(py::module &m) {
}

void bindRegisterDialects(py::module &mod) {
registerQuakeDialectAndTypes(mod);
registerCCDialectAndTypes(mod);

mod.def("load_intrinsic", [](MlirModule module, std::string name) {
auto unwrapped = unwrap(module);
cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody());
Expand All @@ -310,14 +268,17 @@ void bindRegisterDialects(py::module &mod) {

mod.def("register_all_dialects", [](MlirContext context) {
DialectRegistry registry;
registry.insert<quake::QuakeDialect, cudaq::cc::CCDialect>();
cudaq::opt::registerCodeGenDialect(registry);
registerAllDialects(registry);
auto *mlirContext = unwrap(context);
cudaq::registerAllDialects(registry);
MLIRContext *mlirContext = unwrap(context);
mlirContext->appendDialectRegistry(registry);
mlirContext->loadAllAvailableDialects();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does loadAllAvailableDialects actually load Quake, CC, and CodeGen?

Do we need an explicit test to verify that all the dialects we expect to be loaded are loaded?

});

// Register type as passes once, when the module is loaded.
registerQuakeTypes(mod);
registerCCTypes(mod);
cudaq::registerAllPasses();

mod.def("gen_vector_of_complex_constant", [](MlirLocation loc,
MlirModule module,
std::string name,
Expand Down
3 changes: 2 additions & 1 deletion python/tests/mlir/bare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

# RUN: PYTHONPATH=../../ python3 %s | FileCheck %s

from cudaq.mlir import register_all_dialects
from cudaq.mlir.ir import *
from cudaq.mlir.dialects import quake
from cudaq.mlir.dialects import builtin, func, arith

with Context() as ctx:
quake.register_dialect()
register_all_dialects(ctx)
m = Module.create(loc=Location.unknown())
with InsertionPoint(m.body), Location.unknown():
f = func.FuncOp('main', ([], []))
Expand Down
Loading