Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTG] Add BagType CAPI and Python Bindings #7888

Open
wants to merge 1 commit into
base: maerhart-rtg-bags
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/circt-c/Dialect/RTG.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ MLIR_CAPI_EXPORTED bool rtgTypeIsASet(MlirType type);
/// Creates an RTG set type in the context.
MLIR_CAPI_EXPORTED MlirType rtgSetTypeGet(MlirType elementType);

/// If the type is an RTG bag.
MLIR_CAPI_EXPORTED bool rtgTypeIsABag(MlirType type);

/// Creates an RTG bag type in the context.
MLIR_CAPI_EXPORTED MlirType rtgBagTypeGet(MlirType elementType);

/// If the type is an RTG dict.
MLIR_CAPI_EXPORTED bool rtgTypeIsADict(MlirType type);

Expand Down
17 changes: 16 additions & 1 deletion integration_test/Bindings/Python/dialects/rtg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import circt

from circt.dialects import rtg, rtgtest
from circt.ir import Context, Location, Module, InsertionPoint, Block, StringAttr, TypeAttr
from circt.ir import Context, Location, Module, InsertionPoint, Block, StringAttr, TypeAttr, IndexType
from circt.passmanager import PassManager
from circt import rtgtool_support as rtgtool

Expand Down Expand Up @@ -67,3 +67,18 @@
# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK-NEXT: }
print(m)

with Context() as ctx, Location.unknown():
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
indexTy = IndexType.get()
sequenceTy = rtg.SequenceType.get()
setTy = rtg.SetType.get(indexTy)
bagTy = rtg.BagType.get(indexTy)
seq = rtg.SequenceOp('seq')
Block.create_at_start(seq.bodyRegion, [sequenceTy, setTy, bagTy])

# CHECK: rtg.sequence @seq
# CHECK: (%{{.*}}: !rtg.sequence, %{{.*}}: !rtg.set<index>, %{{.*}}: !rtg.bag<index>):
print(m)
8 changes: 8 additions & 0 deletions lib/Bindings/Python/RTGModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ void circt::python::populateDialectRTGSubmodule(py::module &m) {
},
py::arg("self"), py::arg("element_type"));

mlir_type_subclass(m, "BagType", rtgTypeIsABag)
.def_classmethod(
"get",
[](py::object cls, MlirType elementType) {
return cls(rtgBagTypeGet(elementType));
},
py::arg("self"), py::arg("element_type"));

mlir_type_subclass(m, "DictType", rtgTypeIsADict)
.def_classmethod(
"get",
Expand Down
10 changes: 10 additions & 0 deletions lib/CAPI/Dialect/RTG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ MlirType rtgSetTypeGet(MlirType elementType) {
return wrap(SetType::get(ty.getContext(), ty));
}

// BagType
//===----------------------------------------------------------------------===//

bool rtgTypeIsABag(MlirType type) { return isa<BagType>(unwrap(type)); }

MlirType rtgBagTypeGet(MlirType elementType) {
auto ty = unwrap(elementType);
return wrap(BagType::get(ty.getContext(), ty));
}

// DictType
//===----------------------------------------------------------------------===//

Expand Down
11 changes: 11 additions & 0 deletions test/CAPI/rtg.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ static void testSetType(MlirContext ctx) {
mlirTypeDump(setTy);
}

static void testBagType(MlirContext ctx) {
MlirType elTy = mlirIntegerTypeGet(ctx, 32);
MlirType bagTy = rtgBagTypeGet(elTy);

// CHECK: is_bag
fprintf(stderr, rtgTypeIsABag(bagTy) ? "is_bag\n" : "isnot_bag\n");
// CHECK: !rtg.bag<i32>
mlirTypeDump(bagTy);
}

static void testDictType(MlirContext ctx) {
MlirType elTy = mlirIntegerTypeGet(ctx, 32);
MlirAttribute name0 =
Expand Down Expand Up @@ -62,6 +72,7 @@ int main(int argc, char **argv) {

testSequenceType(ctx);
testSetType(ctx);
testBagType(ctx);
testDictType(ctx);

mlirContextDestroy(ctx);
Expand Down
Loading