Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTG] Add BagType and operations #7887

Open
wants to merge 1 commit into
base: maerhart-rtg-python-bindings
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,71 @@ def SetDifferenceOp : RTGOp<"set_difference", [
}];
}

//===- Bag Operations ------------------------------------------------------===//

def BagCreateOp : RTGOp<"bag_create", [Pure, SameVariadicOperandSize]> {
let summary = "constructs a bag";
let description = [{
This operation constructs a bag with the provided values and associated
multiples. This means the bag constructed in the following example contains
two of each `%arg0` and `%arg0` (`{%arg0, %arg0, %arg1, %arg1}`).

```mlir
%0 = arith.constant 2 : index
%1 = rtg.bag_create (%0 x %arg0, %0 x %arg1) : i32
```
}];

let arguments = (ins Variadic<AnyType>:$elements,
Variadic<Index>:$multiples);
let results = (outs BagType:$bag);

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

def BagSelectRandomOp : RTGOp<"bag_select_random", [
Pure,
TypesMatchWith<"output must be element type of input bag", "bag", "output",
"llvm::cast<rtg::BagType>($_self).getElementType()">
]> {
let summary = "select a random element from the bag";
let description = [{
This operation returns an element from the bag selected uniformely at
random. Therefore, the number of duplicates of each element can be used to
bias the distribution.
If the bag does not contain any elements, the behavior of this operation is
undefined.
}];

let arguments = (ins BagType:$bag);
let results = (outs AnyType:$output);

let assemblyFormat = "$bag `:` qualified(type($bag)) attr-dict";
}

def BagDifferenceOp : RTGOp<"bag_difference", [
Pure,
AllTypesMatch<["original", "diff", "output"]>
]> {
let summary = "computes the difference of two bags";
let description = [{
For each element the resulting bag will have as many fewer than the
'original' bag as there are in the 'diff' bag. However, if the 'inf'
attribute is attached, all elements of that kind will be removed (i.e., it
is assumed the 'diff' bag has infinitely many copies of each element).
}];

let arguments = (ins BagType:$original,
BagType:$diff,
UnitAttr:$inf);
let results = (outs BagType:$output);

let assemblyFormat = [{
$original `,` $diff (`inf` $inf^)? `:` qualified(type($output)) attr-dict
}];
}

//===- Test Specification Operations --------------------------------------===//

def TestOp : RTGOp<"test", [
Expand Down
17 changes: 17 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,27 @@ def SetType : RTGTypeDef<"Set"> {
let assemblyFormat = "`<` $elementType `>`";
}

def BagType : RTGTypeDef<"Bag"> {
let summary = "a bag of values";
let description = [{
This type represents a standard bag/multiset datastructure. It does not make
any assumptions about the underlying implementation.
}];

let parameters = (ins "::mlir::Type":$elementType);

let mnemonic = "bag";
let assemblyFormat = "`<` $elementType `>`";
}

class SetTypeOf<Type elementType> : ContainerType<
elementType, SetType.predicate,
"llvm::cast<rtg::SetType>($_self).getElementType()", "set">;

class BagTypeOf<Type elementType> : ContainerType<
elementType, BagType.predicate,
"llvm::cast<rtg::BagType>($_self).getElementType()", "bag">;

def DictType : RTGTypeDef<"Dict"> {
let summary = "a dictionary";
let description = [{
Expand Down
16 changes: 11 additions & 5 deletions include/circt/Dialect/RTG/IR/RTGVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ class RTGOpVisitor {
auto *thisCast = static_cast<ConcreteType *>(this);
return TypeSwitch<Operation *, ResultType>(op)
.template Case<SequenceOp, SequenceClosureOp, SetCreateOp,
SetSelectRandomOp, SetDifferenceOp, InvokeSequenceOp,
TestOp, TargetOp, YieldOp>([&](auto expr) -> ResultType {
return thisCast->visitOp(expr, args...);
})
SetSelectRandomOp, SetDifferenceOp, TestOp,
InvokeSequenceOp, BagCreateOp, BagSelectRandomOp,
BagDifferenceOp, TargetOp, YieldOp>(
[&](auto expr) -> ResultType {
return thisCast->visitOp(expr, args...);
})
.template Case<ContextResourceOpInterface>(
[&](auto expr) -> ResultType {
return thisCast->visitContextResourceOp(expr, args...);
Expand Down Expand Up @@ -79,6 +81,9 @@ class RTGOpVisitor {
HANDLE(SetCreateOp, Unhandled);
HANDLE(SetSelectRandomOp, Unhandled);
HANDLE(SetDifferenceOp, Unhandled);
HANDLE(BagCreateOp, Unhandled);
HANDLE(BagSelectRandomOp, Unhandled);
HANDLE(BagDifferenceOp, Unhandled);
HANDLE(TestOp, Unhandled);
HANDLE(TargetOp, Unhandled);
HANDLE(YieldOp, Unhandled);
Expand All @@ -93,7 +98,7 @@ class RTGTypeVisitor {
ResultType dispatchTypeVisitor(Type type, ExtraArgs... args) {
auto *thisCast = static_cast<ConcreteType *>(this);
return TypeSwitch<Type, ResultType>(type)
.template Case<SequenceType, SetType, DictType>(
.template Case<SequenceType, SetType, BagType, DictType>(
[&](auto expr) -> ResultType {
return thisCast->visitType(expr, args...);
})
Expand Down Expand Up @@ -138,6 +143,7 @@ class RTGTypeVisitor {

HANDLE(SequenceType, Unhandled);
HANDLE(SetType, Unhandled);
HANDLE(BagType, Unhandled);
HANDLE(DictType, Unhandled);
#undef HANDLE
};
Expand Down
72 changes: 72 additions & 0 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,78 @@ LogicalResult SetCreateOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// BagCreateOp
//===----------------------------------------------------------------------===//

ParseResult BagCreateOp::parse(OpAsmParser &parser, OperationState &result) {
llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> elementOperands,
multipleOperands;
Type elemType;

if (!parser.parseOptionalLParen()) {
while (true) {
OpAsmParser::UnresolvedOperand elementOperand, multipleOperand;
if (parser.parseOperand(multipleOperand) || parser.parseKeyword("x") ||
parser.parseOperand(elementOperand))
return failure();

elementOperands.push_back(elementOperand);
multipleOperands.push_back(multipleOperand);

if (parser.parseOptionalComma()) {
if (parser.parseRParen())
return failure();
break;
}
}
}

if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseType(elemType))
return failure();

result.addTypes({BagType::get(result.getContext(), elemType)});

for (auto operand : elementOperands)
if (parser.resolveOperand(operand, elemType, result.operands))
return failure();

for (auto operand : multipleOperands)
if (parser.resolveOperand(operand, IndexType::get(result.getContext()),
result.operands))
return failure();

return success();
}

void BagCreateOp::print(OpAsmPrinter &p) {
p << " ";
if (!getElements().empty())
p << "(";
llvm::interleaveComma(llvm::zip(getElements(), getMultiples()), p,
[&](auto elAndMultiple) {
auto [el, multiple] = elAndMultiple;
p << multiple << " x " << el;
});
if (!getElements().empty())
p << ")";

p.printOptionalAttrDict((*this)->getAttrs());
p << " : " << getBag().getType().getElementType();
}

LogicalResult BagCreateOp::verify() {
if (!llvm::all_equal(getElements().getTypes()))
return emitOpError() << "types of all elements must match";

if (getElements().size() > 0)
if (getElements()[0].getType() != getBag().getType().getElementType())
return emitOpError() << "operand types must match bag element type";

return success();
}

//===----------------------------------------------------------------------===//
// TestOp
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 15 additions & 0 deletions test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ func.func @sets(%arg0: i32, %arg1: i32) {
return
}

// CHECK-LABEL: @bags
rtg.sequence @bags {
^bb0(%arg0: i32, %arg1: i32, %arg2: index):
// CHECK: [[BAG:%.+]] = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32
// CHECK: [[R:%.+]] = rtg.bag_select_random [[BAG]] : !rtg.bag<i32>
// CHECK: [[EMPTY:%.+]] = rtg.bag_create : i32
// CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] : !rtg.bag<i32>
// CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] inf : !rtg.bag<i32>
%bag = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32
%r = rtg.bag_select_random %bag : !rtg.bag<i32>
%empty = rtg.bag_create : i32
%diff = rtg.bag_difference %bag, %empty : !rtg.bag<i32>
%diff2 = rtg.bag_difference %bag, %empty inf : !rtg.bag<i32>
}

// CHECK-LABEL: rtg.target @empty_target : !rtg.dict<> {
// CHECK-NOT: rtg.yield
rtg.target @empty_target : !rtg.dict<> {
Expand Down
16 changes: 16 additions & 0 deletions test/Dialect/RTG/IR/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,19 @@ rtg.test @test : !rtg.dict<a: i32> {
rtg.test @test : !rtg.dict<a: i32, a: i32> {
^bb0(%arg0: i32, %arg1: i32):
}

// -----

rtg.sequence @seq {
^bb0(%arg0: i32, %arg1: i64, %arg2: index):
// expected-error @below {{types of all elements must match}}
"rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i32, i64, index, index) -> !rtg.bag<i32>
}

// -----

rtg.sequence @seq {
^bb0(%arg0: i64, %arg1: i64, %arg2: index):
// expected-error @below {{operand types must match bag element type}}
"rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i64, i64, index, index) -> !rtg.bag<i32>
}
Loading