Skip to content

Commit

Permalink
[RTG] Add Python Bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Nov 24, 2024
1 parent a0999d6 commit c2d53a7
Show file tree
Hide file tree
Showing 13 changed files with 445 additions and 27 deletions.
1 change: 0 additions & 1 deletion include/circt/Dialect/RTGTest/IR/RTGTestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class RTGTestOp<string mnemonic, list<Trait> traits = []> :

def CPUDeclOp : RTGTestOp<"cpu_decl", [
Pure,
ConstantLike,
ContextResourceDefining,
]> {
let summary = "declare a CPU";
Expand Down
69 changes: 69 additions & 0 deletions integration_test/Bindings/Python/dialects/rtg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# REQUIRES: bindings_python
# RUN: %PYTHON% %s | FileCheck %s

import circt

from circt.dialects import rtg, rtgtest
from circt.ir import Context, Location, Module, InsertionPoint, Block, StringAttr, TypeAttr
from circt.passmanager import PassManager
from circt import rtgtool_support as rtgtool

with Context() as ctx, Location.unknown():
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
cpuTy = rtgtest.CPUType.get()
dictTy = rtg.DictType.get(
ctx,
[StringAttr.get('cpu0'), StringAttr.get('cpu1')], [cpuTy, cpuTy])

target = rtg.TargetOp('target_name', TypeAttr.get(dictTy))
targetBlock = Block.create_at_start(target.bodyRegion, [])
with InsertionPoint(targetBlock):
cpu0 = rtgtest.CPUDeclOp(cpuTy, 0)
cpu1 = rtgtest.CPUDeclOp(cpuTy, 1)
rtg.YieldOp([cpu0, cpu1])

test = rtg.TestOp('test_name', TypeAttr.get(dictTy))
Block.create_at_start(test.bodyRegion, [cpuTy, cpuTy])

# CHECK: rtg.target @target_name : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
# CHECK: [[V0:%.+]] = rtgtest.cpu_decl 0
# CHECK: [[V1:%.+]] = rtgtest.cpu_decl 1
# CHECK: rtg.yield [[V0]], [[V1]] : !rtgtest.cpu, !rtgtest.cpu
# CHECK: }
# CHECK: rtg.test @test_name : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
# CHECK: ^bb{{.*}}(%{{.*}}: !rtgtest.cpu, %{{.*}}: !rtgtest.cpu):
# CHECK: }
print(m)

with Context() as ctx, Location.unknown():
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
seq = rtg.SequenceOp('sequence_name')
Block.create_at_start(seq.bodyRegion, [])

test = rtg.TestOp('test_name', TypeAttr.get(rtg.DictType.get()))
block = Block.create_at_start(test.bodyRegion, [])
with InsertionPoint(block):
seq_closure = rtg.SequenceClosureOp('sequence_name', [])
rtg.InvokeSequenceOp(seq_closure)

# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK-NEXT: rtg.sequence_closure
# CHECK-NEXT: rtg.invoke_sequence
# CHECK-NEXT: }
print(m)

pm = PassManager()
options = rtgtool.Options(
output_format=rtgtool.OutputFormat.ELABORATED_MLIR,
debug_mode=True,
)
rtgtool.populate_randomizer_pipeline(pm, options)
pm.run(m.operation)

# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK-NEXT: }
print(m)
23 changes: 23 additions & 0 deletions lib/Bindings/Python/CIRCTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include "circt-c/Dialect/LTL.h"
#include "circt-c/Dialect/MSFT.h"
#include "circt-c/Dialect/OM.h"
#include "circt-c/Dialect/RTG.h"
#ifdef CIRCT_INCLUDE_TESTS
#include "circt-c/Dialect/RTGTest.h"
#endif
#include "circt-c/Dialect/SV.h"
#include "circt-c/Dialect/Seq.h"
#include "circt-c/Dialect/Verif.h"
Expand All @@ -45,6 +49,7 @@ static void registerPasses() {
registerFSMPasses();
registerHWArithPasses();
registerHWPasses();
registerRTGPasses();
registerHandshakePasses();
mlirRegisterConversionPasses();
mlirRegisterTransformsPasses();
Expand Down Expand Up @@ -96,6 +101,16 @@ PYBIND11_MODULE(_circt, m) {
mlirDialectHandleRegisterDialect(om, context);
mlirDialectHandleLoadDialect(om, context);

MlirDialectHandle rtg = mlirGetDialectHandle__rtg__();
mlirDialectHandleRegisterDialect(rtg, context);
mlirDialectHandleLoadDialect(rtg, context);

#ifdef CIRCT_INCLUDE_TESTS
MlirDialectHandle rtgtest = mlirGetDialectHandle__rtgtest__();
mlirDialectHandleRegisterDialect(rtgtest, context);
mlirDialectHandleLoadDialect(rtgtest, context);
#endif

MlirDialectHandle seq = mlirGetDialectHandle__seq__();
mlirDialectHandleRegisterDialect(seq, context);
mlirDialectHandleLoadDialect(seq, context);
Expand Down Expand Up @@ -143,6 +158,14 @@ PYBIND11_MODULE(_circt, m) {
circt::python::populateDialectSeqSubmodule(seq);
py::module om = m.def_submodule("_om", "OM API");
circt::python::populateDialectOMSubmodule(om);
py::module rtg = m.def_submodule("_rtg", "RTG API");
circt::python::populateDialectRTGSubmodule(rtg);
py::module rtgtool = m.def_submodule("_rtgtool", "RTGTool API");
circt::python::populateDialectRTGToolSubmodule(rtgtool);
#ifdef CIRCT_INCLUDE_TESTS
py::module rtgtest = m.def_submodule("_rtgtest", "RTGTest API");
circt::python::populateDialectRTGTestSubmodule(rtgtest);
#endif
py::module sv = m.def_submodule("_sv", "SV API");
circt::python::populateDialectSVSubmodule(sv);
py::module support = m.def_submodule("_support", "CIRCT support");
Expand Down
5 changes: 5 additions & 0 deletions lib/Bindings/Python/CIRCTModules.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ void populateDialectESISubmodule(pybind11::module &m);
void populateDialectHWSubmodule(pybind11::module &m);
void populateDialectMSFTSubmodule(pybind11::module &m);
void populateDialectOMSubmodule(pybind11::module &m);
void populateDialectRTGSubmodule(pybind11::module &m);
void populateDialectRTGToolSubmodule(pybind11::module &m);
#ifdef CIRCT_INCLUDE_TESTS
void populateDialectRTGTestSubmodule(pybind11::module &m);
#endif
void populateDialectSeqSubmodule(pybind11::module &m);
void populateDialectSVSubmodule(pybind11::module &m);
void populateSupportSubmodule(pybind11::module &m);
Expand Down
90 changes: 64 additions & 26 deletions lib/Bindings/Python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,57 @@ add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=circt.")
################################################################################
# Declare native Python extension
################################################################################
set(LLVM_OPTIONAL_SOURCES
RTGTestModule.cpp
)

set(PYTHON_BINDINGS_SOURCES
CIRCTModule.cpp
ESIModule.cpp
HWModule.cpp
OMModule.cpp
RTGModule.cpp
RTGToolModule.cpp
MSFTModule.cpp
SeqModule.cpp
SupportModule.cpp
SVModule.cpp
)

set(PYTHON_BINDINGS_LINK_LIBS
CIRCTCAPIArc
CIRCTCAPIComb
CIRCTCAPIDebug
CIRCTCAPIEmit
CIRCTCAPIESI
CIRCTCAPIExportVerilog
CIRCTCAPIFSM
CIRCTCAPIHW
CIRCTCAPIHWArith
CIRCTCAPIHandshake
CIRCTCAPILTL
CIRCTCAPIMSFT
CIRCTCAPIOM
CIRCTCAPIRTG
CIRCTCAPISV
CIRCTCAPISeq
CIRCTCAPIVerif
CIRCTCAPIConversion
CIRCTCAPIRtgTool
MLIRCAPITransforms
)

if (CIRCT_INCLUDE_TESTS)
list(APPEND PYTHON_BINDINGS_SOURCES RTGTestModule.cpp)
list(APPEND PYTHON_BINDINGS_LINK_LIBS CIRCTCAPIRTGTest)
endif()

declare_mlir_python_extension(CIRCTBindingsPythonExtension
MODULE_NAME _circt
SOURCES
CIRCTModule.cpp
ESIModule.cpp
HWModule.cpp
OMModule.cpp
MSFTModule.cpp
SeqModule.cpp
SupportModule.cpp
SVModule.cpp
${PYTHON_BINDINGS_SOURCES}
EMBED_CAPI_LINK_LIBS
CIRCTCAPIArc
CIRCTCAPIComb
CIRCTCAPIDebug
CIRCTCAPIEmit
CIRCTCAPIESI
CIRCTCAPIExportVerilog
CIRCTCAPIFSM
CIRCTCAPIHW
CIRCTCAPIHWArith
CIRCTCAPIHandshake
CIRCTCAPILTL
CIRCTCAPIMSFT
CIRCTCAPIOM
CIRCTCAPISV
CIRCTCAPISeq
CIRCTCAPIVerif
CIRCTCAPIConversion
MLIRCAPITransforms
${PYTHON_BINDINGS_LINK_LIBS}
PRIVATE_LINK_LIBS
LLVMSupport
)
Expand All @@ -56,6 +75,7 @@ declare_mlir_python_sources(CIRCTBindingsPythonSources
SOURCES
__init__.py
support.py
rtgtool_support.py
)

################################################################################
Expand Down Expand Up @@ -117,6 +137,24 @@ declare_mlir_dialect_python_bindings(
dialects/om.py
DIALECT_NAME om)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT CIRCTBindingsPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
TD_FILE dialects/RTGOps.td
SOURCES
dialects/rtg.py
DIALECT_NAME rtg)

if (CIRCT_INCLUDE_TESTS)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT CIRCTBindingsPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
TD_FILE dialects/RTGTestOps.td
SOURCES
dialects/rtgtest.py
DIALECT_NAME rtgtest)
endif()

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT CIRCTBindingsPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
Expand Down
62 changes: 62 additions & 0 deletions lib/Bindings/Python/RTGModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//===- RTGModule.cpp - RTG API pybind module ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CIRCTModules.h"

#include "circt-c/Dialect/RTG.h"

#include "mlir/Bindings/Python/PybindAdaptors.h"

#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
namespace py = pybind11;

using namespace circt;
using namespace mlir::python::adaptors;

/// Populate the rtg python module.
void circt::python::populateDialectRTGSubmodule(py::module &m) {
m.doc() = "RTG dialect Python native extension";

mlir_type_subclass(m, "SequenceType", rtgTypeIsASequence)
.def_classmethod(
"get",
[](py::object cls, MlirContext ctxt) {
return cls(rtgSequenceTypeGet(ctxt));
},
py::arg("self"), py::arg("ctxt") = nullptr);

mlir_type_subclass(m, "SetType", rtgTypeIsASet)
.def_classmethod(
"get",
[](py::object cls, MlirType elementType) {
return cls(rtgSetTypeGet(elementType));
},
py::arg("self"), py::arg("element_type"));

mlir_type_subclass(m, "DictType", rtgTypeIsADict)
.def_classmethod(
"get",
[](py::object cls, MlirContext ctxt, py::list entryNames,
py::list entryTypes) {
std::vector<MlirAttribute> names;
std::vector<MlirType> types;
for (auto type : entryNames)
names.push_back(type.cast<MlirAttribute>());
for (auto type : entryTypes)
types.push_back(type.cast<MlirType>());
assert(names.size() == types.size() &&
"number of entry names and entry types must match");
return cls(
rtgDictTypeGet(ctxt, types.size(), names.data(), types.data()));
},
py::arg("self"), py::arg("ctxt") = nullptr,
py::arg("entry_names") = py::list(),
py::arg("entry_types") = py::list());
}
34 changes: 34 additions & 0 deletions lib/Bindings/Python/RTGTestModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===- RTGTestModule.cpp - RTGTest API pybind module ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CIRCTModules.h"

#include "circt-c/Dialect/RTGTest.h"

#include "mlir/Bindings/Python/PybindAdaptors.h"

#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
namespace py = pybind11;

using namespace circt;
using namespace mlir::python::adaptors;

/// Populate the rtgtest python module.
void circt::python::populateDialectRTGTestSubmodule(py::module &m) {
m.doc() = "RTGTest dialect Python native extension";

mlir_type_subclass(m, "CPUType", rtgtestTypeIsACPU)
.def_classmethod(
"get",
[](py::object cls, MlirContext ctxt) {
return cls(rtgtestCPUTypeGet(ctxt));
},
py::arg("self"), py::arg("ctxt") = nullptr);
}
Loading

0 comments on commit c2d53a7

Please sign in to comment.