Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Iterators] Implement merge-join op. #682

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,35 @@ def Iterators_MapOp : Iterators_Op<"map",
}];
}

def Iterators_MergeJoinOp : Iterators_Op<"mergejoin", [
PredOpTrait<"the element type of the result stream must be a tuple of the "
"two respective element types of the two input streams",
CPred<[{
$result.getType().cast<StreamType>().getElementType() ==
TupleType::get(
$result.getContext(),
TypeRange{$lhs.getType().cast<StreamType>().getElementType(),
$rhs.getType().cast<StreamType>().getElementType()})}]>>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Join two sorted streams of tuples on their first element.";
let description = [{
}];
let arguments = (ins Iterators_Stream:$lhs, Iterators_Stream:$rhs);
let results = (outs Iterators_StreamOf<AnyTuple>:$result);
let assemblyFormat = [{
$lhs `and` $rhs attr-dict `:` type($result)
custom<JoinTypes>(type($lhs), type($rhs), ref(type($result)))
}];
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
void $cppClass::getAsmResultNames(
llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
setNameFn(getResult(), "joined");
}
}];
}

def Iterators_ReduceOp : Iterators_Op<"reduce",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,26 @@ StateTypeComputer::operator()(MapOp op,
return StateType::get(context, {upstreamStateTypes[0]});
}

/// The state of MergeJoinOp consists of (1) the states of the two upstream ops,
/// (2) the last element successfully consumed from each of the two upstream
/// ops if any, and (3) two Booleans indicating whether these elements exist,
/// respectively.
template <>
StateType
StateTypeComputer::operator()(MergeJoinOp op,
llvm::SmallVector<StateType> upstreamStateTypes) {
MLIRContext *context = op->getContext();
StateType lhsStateType = upstreamStateTypes[0];
StateType rhsStateType = upstreamStateTypes[1];
auto lhsStreamType = op.getLhs().getType().cast<StreamType>();
auto rhsStreamType = op.getRhs().getType().cast<StreamType>();
Type lhsElementType = lhsStreamType.getElementType();
Type rhsElementType = rhsStreamType.getElementType();
Type i1 = IntegerType::get(context, /*width=*/1);
return StateType::get(context, {lhsStateType, i1, lhsElementType,
rhsStateType, i1, rhsElementType});
}

/// The state of ReduceOp only consists of the state of its upstream iterator,
/// i.e., the state of the iterator that produces its input stream.
template <>
Expand Down Expand Up @@ -183,6 +203,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
ConstantStreamOp,
FilterOp,
MapOp,
MergeJoinOp,
ReduceOp,
TabularViewToStreamOp,
ValueToStreamOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ buildNextBody(FilterOp op, OpBuilder &builder, Value initialState,
// If we got an element, apply predicate.
auto ifOp = b.create<scf::IfOp>(
/*condition=*/hasNext,
/*ifBuilder=*/
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);

Expand Down Expand Up @@ -783,7 +783,7 @@ buildNextBody(MapOp op, OpBuilder &builder, Value initialState,
// If we got an element, apply map function.
auto ifOp = b.create<scf::IfOp>(
/*condition=*/hasNext,
/*ifBuilder=*/
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
// Apply map function.
ImplicitLocOpBuilder b(loc, builder);
Expand Down Expand Up @@ -864,6 +864,237 @@ static Value buildStateCreation(MapOp op, MapOp::Adaptor adaptor,
return b.create<CreateStateOp>(stateType, upstreamState);
}

//===----------------------------------------------------------------------===//
// MapOp.
//===----------------------------------------------------------------------===//

/// XXX
static Value buildOpenBody(MergeJoinOp op, OpBuilder &builder,
Value initialState,
ArrayRef<IteratorInfo> upstreamInfos) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

// Open both upstream states.
Value state = initialState;
for (auto i : {0, 1}) {
Type upstreamStateType = upstreamInfos[i].stateType;
IntegerAttr fieldIndex = b.getIndexAttr(i * 3);

// Extract upstream state.
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, fieldIndex);

// Call Open on upstream.
SymbolRefAttr openFunc = upstreamInfos[i].openFunc;
auto callOp = b.create<func::CallOp>(openFunc, upstreamStateType,
initialUpstreamState);

// Update upstream state.
Value updatedUpstreamState = callOp->getResult(0);
state = b.create<iterators::InsertValueOp>(initialState, fieldIndex,
updatedUpstreamState);
}

return state;
}

/// XXX
static llvm::SmallVector<Value, 4>
buildNextBody(MergeJoinOp op, OpBuilder &builder, Value initialState,
ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
MLIRContext *context = b.getContext();
Type i1 = IntegerType::get(context, /*width=*/1);

// Determine various derived types.
auto stateType = initialState.getType().cast<StateType>();
Type lhsStateType = upstreamInfos[0].stateType;
Type rhsStateType = upstreamInfos[1].stateType;
auto lhsStreamType = op.getLhs().getType().cast<StreamType>();
auto rhsStreamType = op.getRhs().getType().cast<StreamType>();
Type lhsElementType = lhsStreamType.getElementType();
Type rhsElementType = rhsStreamType.getElementType();

TypeRange stateTypes{lhsStateType, rhsStateType};
TypeRange elementTypes{lhsElementType, rhsElementType};

SymbolRefAttr lhsNextFunc = upstreamInfos[0].nextFunc;
SymbolRefAttr rhsNextFunc = upstreamInfos[1].nextFunc;
SmallVector<SymbolRefAttr, 2> nextFuncs = {lhsNextFunc, rhsNextFunc};

// Fetch initial upstream elements if required.
SmallVector<Value, 2> upstreamStates(2);
SmallVector<Value, 2> upstreamHasElements(2);
SmallVector<Value, 2> upstreamElements(2);
for (auto i : {0, 1}) {
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
lhsStateType, initialState, b.getIndexAttr(i * 3));
Value initialHasElement = b.create<iterators::ExtractValueOp>(
i1, initialState, b.getIndexAttr(i * 3 + 1));
auto ifOp = b.create<scf::IfOp>(
/*condition=*/initialHasElement,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
// The element stored in the state is valid, so take that and don't
// modify the corresponding upstream state.
Value initialElement = b.create<iterators::ExtractValueOp>(
elementTypes[i], initialState, b.getIndexAttr(i * 3 + 2));
b.create<scf::YieldOp>(ValueRange{initialUpstreamState,
initialHasElement, initialElement});
}, /*elseBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
// The element stored in the state is not valid, so fetch a new one by
// calling Next and return the result of that.
TypeRange resultTypes{stateTypes[i], i1, elementTypes[i]};
auto callOp = b.create<func::CallOp>(nextFuncs[i], resultTypes,
initialUpstreamState);
b.create<scf::YieldOp>(callOp->getResults());
});
upstreamStates[i] = ifOp->getResult(0);
upstreamHasElements[i] = ifOp->getResult(1);
upstreamElements[i] = ifOp->getResult(2);
}

// Main while loop looking for a match.
ValueRange whileInputs // (force formatting)
{upstreamStates[0], upstreamHasElements[0], upstreamElements[0],
upstreamStates[1], upstreamHasElements[1], upstreamElements[1]};
scf::WhileOp whileOp = b.create<scf::WhileOp>(
/*resultTypes=*/whileInputs.getTypes(), whileInputs,
/*beforeBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange args) {
ImplicitLocOpBuilder b(loc, builder);
Value lhsHasElement = args[1];
Value rhsHasElement = args[4];
Value bothSidesHaveElement =
b.create<arith::AndIOp>(lhsHasElement, rhsHasElement);

// XXX: make extendible:
Value lhsElement = args[2];
Value rhsElement = args[5];
Value elementsNotEqual = b.create<arith::CmpIOp>(
arith::CmpIPredicate::ne, lhsElement, rhsElement);

// If the two elements are valid but not the same, we need to continue
// searching for a match.
Value continueLoop =
b.create<arith::AndIOp>(bothSidesHaveElement, elementsNotEqual);
b.create<scf::ConditionOp>(continueLoop, args);
},
/*afterBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange args) {
ImplicitLocOpBuilder b(loc, builder);
Value lhsElement = args[2];
Value rhsElement = args[5];
Value lhsIsSmaller = b.create<arith::CmpIOp>(arith::CmpIPredicate::slt,
lhsElement, rhsElement);
auto ifOp = b.create<scf::IfOp>(
/*condition=*/lhsIsSmaller,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
TypeRange resultTypes{stateTypes[0], i1, elementTypes[0]};
auto callOp =
b.create<func::CallOp>(nextFuncs[0], resultTypes, args[0]);
Value updatedLhsState = callOp->getResult(0);
Value updatedLhsHasElement = callOp->getResult(1);
Value updatedLhsElement = callOp->getResult(2);
b.create<scf::YieldOp>(
ValueRange{updatedLhsState, updatedLhsHasElement,
updatedLhsElement, args[3], args[4], args[5]});
},
/*elseBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
TypeRange resultTypes{stateTypes[1], i1, elementTypes[1]};
auto callOp =
b.create<func::CallOp>(nextFuncs[1], resultTypes, args[3]);
Value updatedRhsState = callOp->getResult(0);
Value updatedRhsHasElement = callOp->getResult(1);
Value updatedRhsElement = callOp->getResult(2);
b.create<scf::YieldOp>(
ValueRange{args[0], args[1], args[2], updatedRhsState,
updatedRhsHasElement, updatedRhsElement});
});

b.create<scf::YieldOp>(ifOp.getResults());
});

Value finalLhsState = whileOp->getResult(0);
Value finalRhsState = whileOp->getResult(3);
Value finalLhsHasElement = whileOp->getResult(1);
Value finalRhsHasElement = whileOp->getResult(4);
Value finalLhsElement = whileOp->getResult(2);
Value finalRhsElement = whileOp->getResult(5);

// Update state and compute return values.
Value constFalse = b.create<arith::ConstantIntOp>(/*value=*/0, /*width=*/1);
Value updatedState = b.create<CreateStateOp>(
stateType, ValueRange{finalLhsState, constFalse, finalLhsElement,
finalRhsState, constFalse, finalRhsElement});

Value nextElement = b.create<tuple::FromElementsOp>(
elementType, ValueRange{finalLhsElement, finalRhsElement});
Value hasNext =
b.create<arith::AndIOp>(finalLhsHasElement, finalRhsHasElement);

return {updatedState, hasNext, nextElement};
}

/// XXX
static Value buildCloseBody(MergeJoinOp op, OpBuilder &builder,
Value initialState,
ArrayRef<IteratorInfo> upstreamInfos) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

// Close both upstream states.
Value state = initialState;
for (auto i : {0, 1}) {
Type upstreamStateType = upstreamInfos[i].stateType;
IntegerAttr fieldIndex = b.getIndexAttr(i * 3);

// Extract upstream state.
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, fieldIndex);

// Call Close on upstream.
SymbolRefAttr closeFunc = upstreamInfos[i].closeFunc;
auto callOp = b.create<func::CallOp>(closeFunc, upstreamStateType,
initialUpstreamState);

// Update upstream state.
Value updatedUpstreamState = callOp->getResult(0);
state = b.create<iterators::InsertValueOp>(initialState, fieldIndex,
updatedUpstreamState);
}

return state;
}

/// XXX
static Value buildStateCreation(MergeJoinOp op, MergeJoinOp::Adaptor adaptor,
OpBuilder &builder, StateType stateType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
Value lhsState = adaptor.getLhs();
Value rhsState = adaptor.getRhs();
Value constFalse = b.create<arith::ConstantIntOp>(/*value=*/0, /*width=*/1);
auto lhsStreamType = op.getLhs().getType().cast<StreamType>();
auto rhsStreamType = op.getRhs().getType().cast<StreamType>();
Type lhsElementType = lhsStreamType.getElementType();
Type rhsElementType = rhsStreamType.getElementType();
Value lhsUndefElement = b.create<LLVM::UndefOp>(lhsElementType);
Value rhsUndefElement = b.create<LLVM::UndefOp>(rhsElementType);
return b.create<CreateStateOp>(
stateType, ValueRange{lhsState, constFalse, lhsUndefElement, // (force nl)
rhsState, constFalse, rhsUndefElement});
}

//===----------------------------------------------------------------------===//
// ReduceOp.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -955,7 +1186,7 @@ buildNextBody(ReduceOp op, OpBuilder &builder, Value initialState,
Value firstHasNext = firstNextCall->getResult(1);
auto ifOp = b.create<scf::IfOp>(
/*condition=*/firstHasNext,
/*ifBuilder=*/
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);

Expand Down Expand Up @@ -1450,7 +1681,6 @@ buildNextBody(ZipOp op, OpBuilder &builder, Value initialState,
/// into %initialState[0] : !iterators.state<upstream_state_type>
static Value buildCloseBody(ZipOp op, OpBuilder &builder, Value initialState,
ArrayRef<IteratorInfo> upstreamInfos) {

Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

Expand Down Expand Up @@ -1546,6 +1776,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder,
ConstantStreamOp,
FilterOp,
MapOp,
MergeJoinOp,
ReduceOp,
TabularViewToStreamOp,
ValueToStreamOp,
Expand All @@ -1566,6 +1797,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState,
ConstantStreamOp,
FilterOp,
MapOp,
MergeJoinOp,
ReduceOp,
TabularViewToStreamOp,
ValueToStreamOp,
Expand All @@ -1587,6 +1819,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder,
ConstantStreamOp,
FilterOp,
MapOp,
MergeJoinOp,
ReduceOp,
TabularViewToStreamOp,
ValueToStreamOp,
Expand All @@ -1606,6 +1839,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder,
ConstantStreamOp,
FilterOp,
MapOp,
MergeJoinOp,
ReduceOp,
TabularViewToStreamOp,
ValueToStreamOp,
Expand Down
Loading