diff --git a/include/circt/Dialect/RTG/IR/RTGOps.td b/include/circt/Dialect/RTG/IR/RTGOps.td index 95985be015cb..6b85fb76073c 100644 --- a/include/circt/Dialect/RTG/IR/RTGOps.td +++ b/include/circt/Dialect/RTG/IR/RTGOps.td @@ -141,6 +141,71 @@ def SetDifferenceOp : RTGOp<"set_difference", [ }]; } +//===- Bag Operations ------------------------------------------------------===// + +def BagCreateOp : RTGOp<"bag_create", [Pure, SameVariadicOperandSize]> { + let summary = "constructs a bag"; + let description = [{ + This operation constructs a bag with the provided values and associated + multiples. This means the bag constructed in the following example contains + two of each `%arg0` and `%arg0` (`{%arg0, %arg0, %arg1, %arg1}`). + + ```mlir + %0 = arith.constant 2 : index + %1 = rtg.bag_create (%0 x %arg0, %0 x %arg1) : i32 + ``` + }]; + + let arguments = (ins Variadic:$elements, + Variadic:$multiples); + let results = (outs BagType:$bag); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def BagSelectRandomOp : RTGOp<"bag_select_random", [ + Pure, + TypesMatchWith<"output must be element type of input bag", "bag", "output", + "llvm::cast($_self).getElementType()"> +]> { + let summary = "select a random element from the bag"; + let description = [{ + This operation returns an element from the bag selected uniformely at + random. Therefore, the number of duplicates of each element can be used to + bias the distribution. + If the bag does not contain any elements, the behavior of this operation is + undefined. + }]; + + let arguments = (ins BagType:$bag); + let results = (outs AnyType:$output); + + let assemblyFormat = "$bag `:` qualified(type($bag)) attr-dict"; +} + +def BagDifferenceOp : RTGOp<"bag_difference", [ + Pure, + AllTypesMatch<["original", "diff", "output"]> +]> { + let summary = "computes the difference of two bags"; + let description = [{ + For each element the resulting bag will have as many fewer than the + 'original' bag as there are in the 'diff' bag. However, if the 'inf' + attribute is attached, all elements of that kind will be removed (i.e., it + is assumed the 'diff' bag has infinitely many copies of each element). + }]; + + let arguments = (ins BagType:$original, + BagType:$diff, + UnitAttr:$inf); + let results = (outs BagType:$output); + + let assemblyFormat = [{ + $original `,` $diff (`inf` $inf^)? `:` qualified(type($output)) attr-dict + }]; +} + //===- Test Specification Operations --------------------------------------===// def TestOp : RTGOp<"test", [ diff --git a/include/circt/Dialect/RTG/IR/RTGTypes.td b/include/circt/Dialect/RTG/IR/RTGTypes.td index b5afba75318b..99e68b1cc657 100644 --- a/include/circt/Dialect/RTG/IR/RTGTypes.td +++ b/include/circt/Dialect/RTG/IR/RTGTypes.td @@ -43,10 +43,27 @@ def SetType : RTGTypeDef<"Set"> { let assemblyFormat = "`<` $elementType `>`"; } +def BagType : RTGTypeDef<"Bag"> { + let summary = "a bag of values"; + let description = [{ + This type represents a standard bag/multiset datastructure. It does not make + any assumptions about the underlying implementation. + }]; + + let parameters = (ins "::mlir::Type":$elementType); + + let mnemonic = "bag"; + let assemblyFormat = "`<` $elementType `>`"; +} + class SetTypeOf : ContainerType< elementType, SetType.predicate, "llvm::cast($_self).getElementType()", "set">; +class BagTypeOf : ContainerType< + elementType, BagType.predicate, + "llvm::cast($_self).getElementType()", "bag">; + def DictType : RTGTypeDef<"Dict"> { let summary = "a dictionary"; let description = [{ diff --git a/include/circt/Dialect/RTG/IR/RTGVisitors.h b/include/circt/Dialect/RTG/IR/RTGVisitors.h index de11a8ee65e3..bcafb7cb0f82 100644 --- a/include/circt/Dialect/RTG/IR/RTGVisitors.h +++ b/include/circt/Dialect/RTG/IR/RTGVisitors.h @@ -31,10 +31,12 @@ class RTGOpVisitor { auto *thisCast = static_cast(this); return TypeSwitch(op) .template Case([&](auto expr) -> ResultType { - return thisCast->visitOp(expr, args...); - }) + SetSelectRandomOp, SetDifferenceOp, TestOp, + InvokeSequenceOp, BagCreateOp, BagSelectRandomOp, + BagDifferenceOp, TargetOp, YieldOp>( + [&](auto expr) -> ResultType { + return thisCast->visitOp(expr, args...); + }) .template Case( [&](auto expr) -> ResultType { return thisCast->visitContextResourceOp(expr, args...); @@ -79,6 +81,9 @@ class RTGOpVisitor { HANDLE(SetCreateOp, Unhandled); HANDLE(SetSelectRandomOp, Unhandled); HANDLE(SetDifferenceOp, Unhandled); + HANDLE(BagCreateOp, Unhandled); + HANDLE(BagSelectRandomOp, Unhandled); + HANDLE(BagDifferenceOp, Unhandled); HANDLE(TestOp, Unhandled); HANDLE(TargetOp, Unhandled); HANDLE(YieldOp, Unhandled); @@ -93,7 +98,7 @@ class RTGTypeVisitor { ResultType dispatchTypeVisitor(Type type, ExtraArgs... args) { auto *thisCast = static_cast(this); return TypeSwitch(type) - .template Case( + .template Case( [&](auto expr) -> ResultType { return thisCast->visitType(expr, args...); }) @@ -138,6 +143,7 @@ class RTGTypeVisitor { HANDLE(SequenceType, Unhandled); HANDLE(SetType, Unhandled); + HANDLE(BagType, Unhandled); HANDLE(DictType, Unhandled); #undef HANDLE }; diff --git a/lib/Dialect/RTG/IR/RTGOps.cpp b/lib/Dialect/RTG/IR/RTGOps.cpp index 3df1bed07e7f..ca7102e99b3b 100644 --- a/lib/Dialect/RTG/IR/RTGOps.cpp +++ b/lib/Dialect/RTG/IR/RTGOps.cpp @@ -78,6 +78,78 @@ LogicalResult SetCreateOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// BagCreateOp +//===----------------------------------------------------------------------===// + +ParseResult BagCreateOp::parse(OpAsmParser &parser, OperationState &result) { + llvm::SmallVector elementOperands, + multipleOperands; + Type elemType; + + if (!parser.parseOptionalLParen()) { + while (true) { + OpAsmParser::UnresolvedOperand elementOperand, multipleOperand; + if (parser.parseOperand(multipleOperand) || parser.parseKeyword("x") || + parser.parseOperand(elementOperand)) + return failure(); + + elementOperands.push_back(elementOperand); + multipleOperands.push_back(multipleOperand); + + if (parser.parseOptionalComma()) { + if (parser.parseRParen()) + return failure(); + break; + } + } + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseType(elemType)) + return failure(); + + result.addTypes({BagType::get(result.getContext(), elemType)}); + + for (auto operand : elementOperands) + if (parser.resolveOperand(operand, elemType, result.operands)) + return failure(); + + for (auto operand : multipleOperands) + if (parser.resolveOperand(operand, IndexType::get(result.getContext()), + result.operands)) + return failure(); + + return success(); +} + +void BagCreateOp::print(OpAsmPrinter &p) { + p << " "; + if (!getElements().empty()) + p << "("; + llvm::interleaveComma(llvm::zip(getElements(), getMultiples()), p, + [&](auto elAndMultiple) { + auto [el, multiple] = elAndMultiple; + p << multiple << " x " << el; + }); + if (!getElements().empty()) + p << ")"; + + p.printOptionalAttrDict((*this)->getAttrs()); + p << " : " << getBag().getType().getElementType(); +} + +LogicalResult BagCreateOp::verify() { + if (!llvm::all_equal(getElements().getTypes())) + return emitOpError() << "types of all elements must match"; + + if (getElements().size() > 0) + if (getElements()[0].getType() != getBag().getType().getElementType()) + return emitOpError() << "operand types must match bag element type"; + + return success(); +} + //===----------------------------------------------------------------------===// // TestOp //===----------------------------------------------------------------------===// diff --git a/test/Dialect/RTG/IR/basic.mlir b/test/Dialect/RTG/IR/basic.mlir index f0070e6726d7..847388a5eb0f 100644 --- a/test/Dialect/RTG/IR/basic.mlir +++ b/test/Dialect/RTG/IR/basic.mlir @@ -39,6 +39,21 @@ func.func @sets(%arg0: i32, %arg1: i32) { return } +// CHECK-LABEL: @bags +rtg.sequence @bags { +^bb0(%arg0: i32, %arg1: i32, %arg2: index): + // CHECK: [[BAG:%.+]] = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 + // CHECK: [[R:%.+]] = rtg.bag_select_random [[BAG]] : !rtg.bag + // CHECK: [[EMPTY:%.+]] = rtg.bag_create : i32 + // CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] : !rtg.bag + // CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] inf : !rtg.bag + %bag = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 + %r = rtg.bag_select_random %bag : !rtg.bag + %empty = rtg.bag_create : i32 + %diff = rtg.bag_difference %bag, %empty : !rtg.bag + %diff2 = rtg.bag_difference %bag, %empty inf : !rtg.bag +} + // CHECK-LABEL: rtg.target @empty_target : !rtg.dict<> { // CHECK-NOT: rtg.yield rtg.target @empty_target : !rtg.dict<> { diff --git a/test/Dialect/RTG/IR/errors.mlir b/test/Dialect/RTG/IR/errors.mlir index f4475e7820ea..c212fbdb8d07 100644 --- a/test/Dialect/RTG/IR/errors.mlir +++ b/test/Dialect/RTG/IR/errors.mlir @@ -36,3 +36,19 @@ rtg.test @test : !rtg.dict { rtg.test @test : !rtg.dict { ^bb0(%arg0: i32, %arg1: i32): } + +// ----- + +rtg.sequence @seq { +^bb0(%arg0: i32, %arg1: i64, %arg2: index): + // expected-error @below {{types of all elements must match}} + "rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i32, i64, index, index) -> !rtg.bag +} + +// ----- + +rtg.sequence @seq { +^bb0(%arg0: i64, %arg1: i64, %arg2: index): + // expected-error @below {{operand types must match bag element type}} + "rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i64, i64, index, index) -> !rtg.bag +}