From ccaf772f5155a824ee36f57dfd15e65347791a0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 28 Mar 2023 05:28:58 +0000 Subject: [PATCH 01/10] xxx-tablegen --- .../Dialect/Iterators/IR/IteratorsOps.td | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td b/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td index e8e47af68533..722d7763ddac 100644 --- a/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td +++ b/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td @@ -258,6 +258,28 @@ def Iterators_MapOp : Iterators_Op<"map", }]; } +def Iterators_MergeJoinOp : Iterators_Op<"mergejoin", // XXX: add type constraint + [DeclareOpInterfaceMethods]> { + let summary = "Join two sorted streams of tuples on their first element."; + let description = [{ + }]; + let arguments = (ins + Iterators_StreamOf:$lhs, + Iterators_StreamOf:$rhs + ); + let results = (outs Iterators_StreamOf:$result); + let assemblyFormat = [{ + $lhs `and` $rhs attr-dict `:` functional-type(operands, results) + }]; + let extraClassDefinition = [{ + /// Implement OpAsmOpInterface. 
+ void $cppClass::getAsmResultNames( + llvm::function_ref setNameFn) { + setNameFn(getResult(), "joined"); + } + }]; +} + def Iterators_ReduceOp : Iterators_Op<"reduce", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { From c86ae1ed71f3dac04498030679ad3b75adffaf38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 11 Apr 2023 14:01:27 +0000 Subject: [PATCH 02/10] while-op.mlir --- experimental/iterators/while-op.mlir | 212 +++++++++++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100644 experimental/iterators/while-op.mlir diff --git a/experimental/iterators/while-op.mlir b/experimental/iterators/while-op.mlir new file mode 100644 index 000000000000..3ccbeaa134ad --- /dev/null +++ b/experimental/iterators/while-op.mlir @@ -0,0 +1,212 @@ + llvm.mlir.global internal constant @iterators.frmt_spec.0("-\0A\00") {addr_space = 0 : i32} + llvm.func @printf(!llvm.ptr, ...) -> i32 + llvm.mlir.global internal constant @iterators.frmt_spec("(%llu, %llu)\0A\00") {addr_space = 0 : i32} + func.func private @iterators.constantstream.close.1(%arg0: !iterators.state) -> !iterators.state { + return %arg0 : !iterators.state + } + llvm.mlir.global internal constant @iterators.constant_stream_data.1() {addr_space = 0 : i32} : !llvm.array<4 x struct<(i32)>> { + %0 = llvm.mlir.undef : !llvm.array<4 x struct<(i32)>> + %1 = llvm.mlir.undef : !llvm.struct<(i32)> + %2 = llvm.mlir.constant(2 : i32) : i32 + %3 = llvm.insertvalue %2, %1[0] : !llvm.struct<(i32)> + %4 = llvm.insertvalue %3, %0[0] : !llvm.array<4 x struct<(i32)>> + %5 = llvm.mlir.undef : !llvm.struct<(i32)> + %6 = llvm.mlir.constant(4 : i32) : i32 + %7 = llvm.insertvalue %6, %5[0] : !llvm.struct<(i32)> + %8 = llvm.insertvalue %7, %4[1] : !llvm.array<4 x struct<(i32)>> + %9 = llvm.mlir.undef : !llvm.struct<(i32)> + %10 = llvm.mlir.constant(6 : i32) : i32 + %11 = llvm.insertvalue %10, %9[0] : !llvm.struct<(i32)> + %12 = llvm.insertvalue %11, %8[2] : !llvm.array<4 x struct<(i32)>> + %13 = 
llvm.mlir.undef : !llvm.struct<(i32)> + %14 = llvm.mlir.constant(8 : i32) : i32 + %15 = llvm.insertvalue %14, %13[0] : !llvm.struct<(i32)> + %16 = llvm.insertvalue %15, %12[3] : !llvm.array<4 x struct<(i32)>> + llvm.return %16 : !llvm.array<4 x struct<(i32)>> + } + func.func private @iterators.constantstream.next.1(%arg0: !iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) { + %0 = iterators.extractvalue %arg0[0] : !iterators.state + %c4_i32 = arith.constant 4 : i32 + %1 = arith.cmpi slt, %0, %c4_i32 : i32 + %2:2 = scf.if %1 -> (!iterators.state, !llvm.struct<(i32)>) { + %c1_i32 = arith.constant 1 : i32 + %3 = arith.addi %0, %c1_i32 : i32 + %state = iterators.insertvalue %3 into %arg0[0] : !iterators.state + %4 = llvm.mlir.addressof @iterators.constant_stream_data.1 : !llvm.ptr + %5 = llvm.getelementptr %4[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32)> + %6 = llvm.load %5 : !llvm.ptr -> !llvm.struct<(i32)> + scf.yield %state, %6 : !iterators.state, !llvm.struct<(i32)> + } else { + %3 = llvm.mlir.undef : !llvm.struct<(i32)> + scf.yield %arg0, %3 : !iterators.state, !llvm.struct<(i32)> + } + return %2#0, %1, %2#1 : !iterators.state, i1, !llvm.struct<(i32)> + } + func.func private @iterators.constantstream.open.1(%arg0: !iterators.state) -> !iterators.state { + %c0_i32 = arith.constant 0 : i32 + %state = iterators.insertvalue %c0_i32 into %arg0[0] : !iterators.state + return %state : !iterators.state + } + func.func private @iterators.constantstream.close.0(%arg0: !iterators.state) -> !iterators.state { + return %arg0 : !iterators.state + } + llvm.mlir.global internal constant @iterators.constant_stream_data.0() {addr_space = 0 : i32} : !llvm.array<4 x struct<(i32)>> { + %0 = llvm.mlir.undef : !llvm.array<4 x struct<(i32)>> + %1 = llvm.mlir.undef : !llvm.struct<(i32)> + %2 = llvm.mlir.constant(0 : i32) : i32 + %3 = llvm.insertvalue %2, %1[0] : !llvm.struct<(i32)> + %4 = llvm.insertvalue %3, %0[0] : !llvm.array<4 x struct<(i32)>> + %5 = 
llvm.mlir.undef : !llvm.struct<(i32)> + %6 = llvm.mlir.constant(1 : i32) : i32 + %7 = llvm.insertvalue %6, %5[0] : !llvm.struct<(i32)> + %8 = llvm.insertvalue %7, %4[1] : !llvm.array<4 x struct<(i32)>> + %9 = llvm.mlir.undef : !llvm.struct<(i32)> + %10 = llvm.mlir.constant(2 : i32) : i32 + %11 = llvm.insertvalue %10, %9[0] : !llvm.struct<(i32)> + %12 = llvm.insertvalue %11, %8[2] : !llvm.array<4 x struct<(i32)>> + %13 = llvm.mlir.undef : !llvm.struct<(i32)> + %14 = llvm.mlir.constant(3 : i32) : i32 + %15 = llvm.insertvalue %14, %13[0] : !llvm.struct<(i32)> + %16 = llvm.insertvalue %15, %12[3] : !llvm.array<4 x struct<(i32)>> + llvm.return %16 : !llvm.array<4 x struct<(i32)>> + } + func.func private @iterators.constantstream.next.0(%arg0: !iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) { + %0 = iterators.extractvalue %arg0[0] : !iterators.state + %c4_i32 = arith.constant 4 : i32 + %1 = arith.cmpi slt, %0, %c4_i32 : i32 + %2:2 = scf.if %1 -> (!iterators.state, !llvm.struct<(i32)>) { + %c1_i32 = arith.constant 1 : i32 + %3 = arith.addi %0, %c1_i32 : i32 + %state = iterators.insertvalue %3 into %arg0[0] : !iterators.state + %4 = llvm.mlir.addressof @iterators.constant_stream_data.0 : !llvm.ptr + %5 = llvm.getelementptr %4[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32)> + %6 = llvm.load %5 : !llvm.ptr -> !llvm.struct<(i32)> + scf.yield %state, %6 : !iterators.state, !llvm.struct<(i32)> + } else { + %3 = llvm.mlir.undef : !llvm.struct<(i32)> + scf.yield %arg0, %3 : !iterators.state, !llvm.struct<(i32)> + } + return %2#0, %1, %2#1 : !iterators.state, i1, !llvm.struct<(i32)> + } + func.func private @iterators.constantstream.open.0(%arg0: !iterators.state) -> !iterators.state { + %c0_i32 = arith.constant 0 : i32 + %state = iterators.insertvalue %c0_i32 into %arg0[0] : !iterators.state + return %state : !iterators.state + } + !state_type = + !iterators.state< + !iterators.state, // lhs state + !iterators.state, // rhs state + !llvm.struct<(i32)>, 
i1, // lhs value, hasValue + !llvm.struct<(i32)>, i1 // rhs value, hasValue + > + func.func private @iterators.constantstream.close.2(%arg0: !state_type) -> !state_type { + %lhs_state = iterators.extractvalue %arg0[0] : !state_type + %rhs_state = iterators.extractvalue %arg0[1] : !state_type + %0 = call @iterators.constantstream.close.0(%lhs_state) : (!iterators.state) -> !iterators.state + %1 = call @iterators.constantstream.close.0(%rhs_state) : (!iterators.state) -> !iterators.state + %state_0 = iterators.insertvalue %0 into %arg0[0] : !state_type + %state_1 = iterators.insertvalue %1 into %state_0[0] : !state_type + return %state_1 : !state_type + } + func.func private @iterators.constantstream.next.2(%arg0: !state_type) -> (!state_type, i1, !llvm.struct<(i32, i32)>) { + // Pseudocode: + // value = undef + // hasValue = false + // if !lhsHasValue: + // lhsState, lhsHasValue, lhsValue = next(lhsState) + // if !rhsHasValue: + // rhsState, rhsHasValue, rhsValue = next(rhsState) + // while (lhsHasValue && rhsHasValue) + // if lhsValue < rhsValue: + // lhsState, lhsHasValue, lhsValue = next(lhsState) + // // return half-undef tuple for outer join + // continue + // if lhsValue > rhsValue: + // rhsState, rhsHasValue, rhsValue = next(rhsState) + // // return half-undef tuple for outer join + // continue + // // assert (lhsValue == rhsValue) + // value = tuple(lhsValue, rhsValue) + // hasValue = true + // lhsHasValue = false + // rhsHasValue = false + // break + // return state, hasValue, value + + // Fetch initial upstream elements if required. 
+ %initialLhsState = iterators.extractvalue %arg0[0] : !state_type + %initialRhsState = iterators.extractvalue %arg0[1] : !state_type + %initialLhsHasValue = iterators.extractvalue %arg0[3] : !state_type + %initialRhsHasValue = iterators.extractvalue %arg0[5] : !state_type + %updatedLhsState, %lhsHasValue, %lhsValue = scf.if %initialLhsHasValue -> (!iterators.state, i1, !llvm.struct<(i32)>) { + %initialLhsValue = iterators.extractvalue %arg0[2] : !state_type + scf.yield %initialLhsState, %initialLhsHasValue, %initialLhsValue : !iterators.state, i1, !llvm.struct<(i32)> + } else { + %nextResult:3 = func.call @iterators.constantstream.next.0(%initialLhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) + scf.yield %nextResult#0, %nextResult#1, %nextResult#2 : !iterators.state, i1, !llvm.struct<(i32)> + } + %updatedRhsState, %rhsHasValue, %rhsValue = scf.if %initialRhsHasValue -> (!iterators.state, i1, !llvm.struct<(i32)>) { + %initialRhsValue = iterators.extractvalue %arg0[2] : !state_type + scf.yield %initialRhsState, %initialRhsHasValue, %initialRhsValue : !iterators.state, i1, !llvm.struct<(i32)> + } else { + %nextResult:3 = func.call @iterators.constantstream.next.0(%initialRhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) + scf.yield %nextResult#0, %nextResult#1, %nextResult#2 : !iterators.state, i1, !llvm.struct<(i32)> + } + + // Main while loop looking for a match. 
+ %finalLhsState, %finalLhsHasValue, %finalLhsValue, %finalRhsState, %finalRhsHasValue, %finalRhsValue = + scf.while(%loopLhsState = %updatedLhsState, %loopLhsHasValue = %lhsHasValue, %loopLhsValue = %lhsValue, + %loopRhsState = %updatedRhsState, %loopRhsHasValue = %rhsHasValue, %loopRhsValue = %rhsValue) + : (!iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)>) + -> (!iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)>) { + scf.condition + } do { + + } + + + %0 = llvm.mlir.undef : !llvm.struct<(i32, i32)> + %1 = arith.constant false + return %arg0, %1, %0 : !state_type, i1, !llvm.struct<(i32, i32)> + } + func.func private @iterators.constantstream.open.2(%arg0: !state_type) -> !state_type { + %lhs_state = iterators.extractvalue %arg0[0] : !state_type + %rhs_state = iterators.extractvalue %arg0[1] : !state_type + %0 = call @iterators.constantstream.open.0(%lhs_state) : (!iterators.state) -> !iterators.state + %1 = call @iterators.constantstream.open.0(%rhs_state) : (!iterators.state) -> !iterators.state + %false = arith.constant false + %state_0 = iterators.insertvalue %0 into %arg0[0] : !state_type + %state_1 = iterators.insertvalue %1 into %state_0[1] : !state_type + %state_2 = iterators.insertvalue %false into %state_1[3] : !state_type + %state_3 = iterators.insertvalue %false into %state_2[5] : !state_type + return %state_3 : !state_type + } + func.func @main() { + %c0_i32 = arith.constant 0 : i32 + %state = iterators.createstate(%c0_i32) : !iterators.state + %state_1 = iterators.createstate(%c0_i32) : !iterators.state + %undef = llvm.mlir.undef : !llvm.struct<(i32)> + %false = arith.constant false + %state_3 = iterators.createstate(%state, %state_1, %undef, %false, %undef, %false) : !state_type + %0 = call @iterators.constantstream.open.2(%state_3) : (!state_type) -> !state_type + %1:2 = scf.while (%arg0 = %0) : (!state_type) -> (!state_type, !llvm.struct<(i32, i32)>) { + %6:3 = func.call 
@iterators.constantstream.next.2(%arg0) : (!state_type) -> (!state_type, i1, !llvm.struct<(i32, i32)>) + scf.condition(%6#1) %6#0, %6#2 : !state_type, !llvm.struct<(i32, i32)> + } do { + ^bb0(%arg0: !state_type, %arg1: !llvm.struct<(i32, i32)>): + %6 = llvm.extractvalue %arg1[0] : !llvm.struct<(i32, i32)> + %7 = arith.extui %6 : i32 to i64 + %8 = llvm.extractvalue %arg1[1] : !llvm.struct<(i32, i32)> + %9 = arith.extui %8 : i32 to i64 + %10 = llvm.mlir.addressof @iterators.frmt_spec : !llvm.ptr + %11 = llvm.getelementptr %10[0] : (!llvm.ptr) -> !llvm.ptr, i8 + %12 = llvm.call @printf(%11, %7, %9) : (!llvm.ptr, i64, i64) -> i32 + scf.yield %arg0 : !state_type + } + %2 = call @iterators.constantstream.close.2(%1#0) : (!state_type) -> !state_type + %3 = llvm.mlir.addressof @iterators.frmt_spec.0 : !llvm.ptr + %4 = llvm.getelementptr %3[0] : (!llvm.ptr) -> !llvm.ptr, i8 + %5 = llvm.call @printf(%4) : (!llvm.ptr) -> i32 + return + } From 311b9983d0928ffc20be40c2bf188933785f5abb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 11 Apr 2023 14:35:18 +0000 Subject: [PATCH 03/10] while-op.mlir --- experimental/iterators/while-op.mlir | 44 +++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/experimental/iterators/while-op.mlir b/experimental/iterators/while-op.mlir index 3ccbeaa134ad..4642a9dedd60 100644 --- a/experimental/iterators/while-op.mlir +++ b/experimental/iterators/while-op.mlir @@ -160,15 +160,51 @@ %loopRhsState = %updatedRhsState, %loopRhsHasValue = %rhsHasValue, %loopRhsValue = %rhsValue) : (!iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)>) -> (!iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)>) { - scf.condition + // If both sides still have a value (i.e., they have not reached the end of their stream) but the current values are different, we need to continue the main loop to find a matching pair. 
+ %bothSidesHaveValue = arith.andi %loopLhsHasValue, %loopRhsHasValue : i1 + %lhsi = llvm.extractvalue %loopLhsValue[0] : !llvm.struct<(i32)> + %rhsi = llvm.extractvalue %loopRhsValue[0] : !llvm.struct<(i32)> + %valuesNotEqual = arith.cmpi "ne", %lhsi, %rhsi : i32 + %continue = arith.andi %bothSidesHaveValue, %valuesNotEqual : i1 + scf.condition (%continue) + %loopLhsState, %loopLhsHasValue, %loopLhsValue, %loopRhsState, %loopRhsHasValue, %loopRhsValue + : !iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)> } do { + ^bb(%loopLhsState: !iterators.state, %loopLhsHasValue: i1, %loopLhsValue: !llvm.struct<(i32)>, + %loopRhsState: !iterators.state, %loopRhsHasValue: i1, %loopRhsValue: !llvm.struct<(i32)>): + %lhsi = llvm.extractvalue %loopLhsValue[0] : !llvm.struct<(i32)> + %rhsi = llvm.extractvalue %loopRhsValue[0] : !llvm.struct<(i32)> + %isLhsSmaller = arith.cmpi "slt", %lhsi, %rhsi : i32 + %branchedLhsState, %branchedLhsHasValue, %branchedLhsValue, %branchedRhsState, %branchedRhsHasValue, %branchedRhsValue = + scf.if %isLhsSmaller -> (!iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)>) { + // If the LHS value was smaller, we need to advance the LHS input. + %nextLhsState, %nextLhsHasValue, %nextLhsValue = func.call @iterators.constantstream.next.0(%loopLhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) + scf.yield %nextLhsState, %nextLhsHasValue, %nextLhsValue, %loopRhsState, %loopRhsHasValue, %loopRhsValue : !iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)> + } else { + // If the RHS value was smaller, we need to advance the RHS input. 
+ %nextRhsState, %nextRhsHasValue, %nextRhsValue = func.call @iterators.constantstream.next.0(%loopRhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) + scf.yield %loopLhsState, %loopLhsHasValue, %loopLhsValue, %nextRhsState, %nextRhsHasValue, %nextRhsValue : !iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)> + } + scf.yield %branchedLhsState, %branchedLhsHasValue, %branchedLhsValue, %branchedRhsState, %branchedRhsHasValue, %branchedRhsValue + : !iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)> } + // Update state. Set lhsHasvalue and rhsHasValue to false because emitting a result consumes them. + %false = arith.constant false + %updatedState = iterators.createstate(%finalLhsState, %finalRhsState, + %finalLhsValue, %false, + %finalRhsValue, %false) : !state_type + + // Concatenate the two structs. (This is working on undefined structs if one of the two streams has finished, i.e., if %bothSidesHaveValue is false.) 
+ %bothSidesHaveValue = arith.andi %finalLhsHasValue, %finalRhsHasValue : i1 + %lhsi = llvm.extractvalue %finalLhsValue[0] : !llvm.struct<(i32)> + %rhsi = llvm.extractvalue %finalRhsValue[0] : !llvm.struct<(i32)> + %structu = llvm.mlir.undef : !llvm.struct<(i32, i32)> + %struct0 = llvm.insertvalue %lhsi, %structu[0] : !llvm.struct<(i32, i32)> + %struct1 = llvm.insertvalue %rhsi, %struct0[1] : !llvm.struct<(i32, i32)> - %0 = llvm.mlir.undef : !llvm.struct<(i32, i32)> - %1 = arith.constant false - return %arg0, %1, %0 : !state_type, i1, !llvm.struct<(i32, i32)> + return %updatedState, %bothSidesHaveValue, %struct1 : !state_type, i1, !llvm.struct<(i32, i32)> } func.func private @iterators.constantstream.open.2(%arg0: !state_type) -> !state_type { %lhs_state = iterators.extractvalue %arg0[0] : !state_type From 93b1286222fdb935a6bebc692b15e5559cb17c17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 11 Apr 2023 15:03:04 +0000 Subject: [PATCH 04/10] while-op.mlir --- .../iterators/{ => test}/while-op.mlir | 50 +++++++++++++++---- 1 file changed, 41 insertions(+), 9 deletions(-) rename experimental/iterators/{ => test}/while-op.mlir (87%) diff --git a/experimental/iterators/while-op.mlir b/experimental/iterators/test/while-op.mlir similarity index 87% rename from experimental/iterators/while-op.mlir rename to experimental/iterators/test/while-op.mlir index 4642a9dedd60..350133c5f006 100644 --- a/experimental/iterators/while-op.mlir +++ b/experimental/iterators/test/while-op.mlir @@ -1,3 +1,11 @@ +// RUN: iterators-opt %s \ +// RUN: -convert-iterators-to-llvm \ +// RUN: -decompose-iterator-states \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-func-to-llvm \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: | FileCheck %s + llvm.mlir.global internal constant @iterators.frmt_spec.0("-\0A\00") {addr_space = 0 : i32} llvm.func @printf(!llvm.ptr, ...) 
-> i32 llvm.mlir.global internal constant @iterators.frmt_spec("(%llu, %llu)\0A\00") {addr_space = 0 : i32} @@ -7,19 +15,19 @@ llvm.mlir.global internal constant @iterators.constant_stream_data.1() {addr_space = 0 : i32} : !llvm.array<4 x struct<(i32)>> { %0 = llvm.mlir.undef : !llvm.array<4 x struct<(i32)>> %1 = llvm.mlir.undef : !llvm.struct<(i32)> - %2 = llvm.mlir.constant(2 : i32) : i32 + %2 = llvm.mlir.constant(0 : i32) : i32 %3 = llvm.insertvalue %2, %1[0] : !llvm.struct<(i32)> %4 = llvm.insertvalue %3, %0[0] : !llvm.array<4 x struct<(i32)>> %5 = llvm.mlir.undef : !llvm.struct<(i32)> - %6 = llvm.mlir.constant(4 : i32) : i32 + %6 = llvm.mlir.constant(2 : i32) : i32 %7 = llvm.insertvalue %6, %5[0] : !llvm.struct<(i32)> %8 = llvm.insertvalue %7, %4[1] : !llvm.array<4 x struct<(i32)>> %9 = llvm.mlir.undef : !llvm.struct<(i32)> - %10 = llvm.mlir.constant(6 : i32) : i32 + %10 = llvm.mlir.constant(4 : i32) : i32 %11 = llvm.insertvalue %10, %9[0] : !llvm.struct<(i32)> %12 = llvm.insertvalue %11, %8[2] : !llvm.array<4 x struct<(i32)>> %13 = llvm.mlir.undef : !llvm.struct<(i32)> - %14 = llvm.mlir.constant(8 : i32) : i32 + %14 = llvm.mlir.constant(6 : i32) : i32 %15 = llvm.insertvalue %14, %13[0] : !llvm.struct<(i32)> %16 = llvm.insertvalue %15, %12[3] : !llvm.array<4 x struct<(i32)>> llvm.return %16 : !llvm.array<4 x struct<(i32)>> @@ -70,7 +78,7 @@ %16 = llvm.insertvalue %15, %12[3] : !llvm.array<4 x struct<(i32)>> llvm.return %16 : !llvm.array<4 x struct<(i32)>> } - func.func private @iterators.constantstream.next.0(%arg0: !iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) { + func.func private @iterators.constantstream.next.0.lhs(%arg0: !iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) { %0 = iterators.extractvalue %arg0[0] : !iterators.state %c4_i32 = arith.constant 4 : i32 %1 = arith.cmpi slt, %0, %c4_i32 : i32 @@ -88,6 +96,24 @@ } return %2#0, %1, %2#1 : !iterators.state, i1, !llvm.struct<(i32)> } + func.func private 
@iterators.constantstream.next.0.rhs(%arg0: !iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) { + %0 = iterators.extractvalue %arg0[0] : !iterators.state + %c4_i32 = arith.constant 4 : i32 + %1 = arith.cmpi slt, %0, %c4_i32 : i32 + %2:2 = scf.if %1 -> (!iterators.state, !llvm.struct<(i32)>) { + %c1_i32 = arith.constant 1 : i32 + %3 = arith.addi %0, %c1_i32 : i32 + %state = iterators.insertvalue %3 into %arg0[0] : !iterators.state + %4 = llvm.mlir.addressof @iterators.constant_stream_data.1 : !llvm.ptr + %5 = llvm.getelementptr %4[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32)> + %6 = llvm.load %5 : !llvm.ptr -> !llvm.struct<(i32)> + scf.yield %state, %6 : !iterators.state, !llvm.struct<(i32)> + } else { + %3 = llvm.mlir.undef : !llvm.struct<(i32)> + scf.yield %arg0, %3 : !iterators.state, !llvm.struct<(i32)> + } + return %2#0, %1, %2#1 : !iterators.state, i1, !llvm.struct<(i32)> + } func.func private @iterators.constantstream.open.0(%arg0: !iterators.state) -> !iterators.state { %c0_i32 = arith.constant 0 : i32 %state = iterators.insertvalue %c0_i32 into %arg0[0] : !iterators.state @@ -143,14 +169,14 @@ %initialLhsValue = iterators.extractvalue %arg0[2] : !state_type scf.yield %initialLhsState, %initialLhsHasValue, %initialLhsValue : !iterators.state, i1, !llvm.struct<(i32)> } else { - %nextResult:3 = func.call @iterators.constantstream.next.0(%initialLhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) + %nextResult:3 = func.call @iterators.constantstream.next.0.lhs(%initialLhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) scf.yield %nextResult#0, %nextResult#1, %nextResult#2 : !iterators.state, i1, !llvm.struct<(i32)> } %updatedRhsState, %rhsHasValue, %rhsValue = scf.if %initialRhsHasValue -> (!iterators.state, i1, !llvm.struct<(i32)>) { %initialRhsValue = iterators.extractvalue %arg0[2] : !state_type scf.yield %initialRhsState, %initialRhsHasValue, %initialRhsValue : 
!iterators.state, i1, !llvm.struct<(i32)> } else { - %nextResult:3 = func.call @iterators.constantstream.next.0(%initialRhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) + %nextResult:3 = func.call @iterators.constantstream.next.0.rhs(%initialRhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) scf.yield %nextResult#0, %nextResult#1, %nextResult#2 : !iterators.state, i1, !llvm.struct<(i32)> } @@ -178,11 +204,11 @@ %branchedLhsState, %branchedLhsHasValue, %branchedLhsValue, %branchedRhsState, %branchedRhsHasValue, %branchedRhsValue = scf.if %isLhsSmaller -> (!iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)>) { // If the LHS value was smaller, we need to advance the LHS input. - %nextLhsState, %nextLhsHasValue, %nextLhsValue = func.call @iterators.constantstream.next.0(%loopLhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) + %nextLhsState, %nextLhsHasValue, %nextLhsValue = func.call @iterators.constantstream.next.0.lhs(%loopLhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) scf.yield %nextLhsState, %nextLhsHasValue, %nextLhsValue, %loopRhsState, %loopRhsHasValue, %loopRhsValue : !iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)> } else { // If the RHS value was smaller, we need to advance the RHS input. 
- %nextRhsState, %nextRhsHasValue, %nextRhsValue = func.call @iterators.constantstream.next.0(%loopRhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) + %nextRhsState, %nextRhsHasValue, %nextRhsValue = func.call @iterators.constantstream.next.0.rhs(%loopRhsState) : (!iterators.state) -> (!iterators.state, i1, !llvm.struct<(i32)>) scf.yield %loopLhsState, %loopLhsHasValue, %loopLhsValue, %nextRhsState, %nextRhsHasValue, %nextRhsValue : !iterators.state, i1, !llvm.struct<(i32)>, !iterators.state, i1, !llvm.struct<(i32)> } @@ -218,7 +244,13 @@ %state_3 = iterators.insertvalue %false into %state_2[5] : !state_type return %state_3 : !state_type } + + // CHECK-LABEL: while-op + // CHECK-NEXT: (0, 0) + // CHECK-NEXT: (2, 2) + // CHECK-NEXT: - func.func @main() { + iterators.print("while-op") %c0_i32 = arith.constant 0 : i32 %state = iterators.createstate(%c0_i32) : !iterators.state %state_1 = iterators.createstate(%c0_i32) : !iterators.state From cb9e8137a793defa2fbd651f57dde83c83297bb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 12 Apr 2023 09:02:55 +0000 Subject: [PATCH 05/10] while-op.mlir --- experimental/iterators/test/while-op.mlir | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/experimental/iterators/test/while-op.mlir b/experimental/iterators/test/while-op.mlir index 350133c5f006..777bfd07a55e 100644 --- a/experimental/iterators/test/while-op.mlir +++ b/experimental/iterators/test/while-op.mlir @@ -126,7 +126,7 @@ !llvm.struct<(i32)>, i1, // lhs value, hasValue !llvm.struct<(i32)>, i1 // rhs value, hasValue > - func.func private @iterators.constantstream.close.2(%arg0: !state_type) -> !state_type { + func.func private @iterators.mergejoin.close.2(%arg0: !state_type) -> !state_type { %lhs_state = iterators.extractvalue %arg0[0] : !state_type %rhs_state = iterators.extractvalue %arg0[1] : !state_type %0 = call @iterators.constantstream.close.0(%lhs_state) : (!iterators.state) -> 
!iterators.state @@ -135,7 +135,7 @@ %state_1 = iterators.insertvalue %1 into %state_0[0] : !state_type return %state_1 : !state_type } - func.func private @iterators.constantstream.next.2(%arg0: !state_type) -> (!state_type, i1, !llvm.struct<(i32, i32)>) { + func.func private @iterators.mergejoin.next.2(%arg0: !state_type) -> (!state_type, i1, !llvm.struct<(i32, i32)>) { // Pseudocode: // value = undef // hasValue = false @@ -217,6 +217,7 @@ } // Update state. Set lhsHasvalue and rhsHasValue to false because emitting a result consumes them. + // XXX: Alternatively, we could call next on both sides here. %false = arith.constant false %updatedState = iterators.createstate(%finalLhsState, %finalRhsState, %finalLhsValue, %false, @@ -232,7 +233,7 @@ return %updatedState, %bothSidesHaveValue, %struct1 : !state_type, i1, !llvm.struct<(i32, i32)> } - func.func private @iterators.constantstream.open.2(%arg0: !state_type) -> !state_type { + func.func private @iterators.mergejoin.open.2(%arg0: !state_type) -> !state_type { %lhs_state = iterators.extractvalue %arg0[0] : !state_type %rhs_state = iterators.extractvalue %arg0[1] : !state_type %0 = call @iterators.constantstream.open.0(%lhs_state) : (!iterators.state) -> !iterators.state @@ -257,9 +258,9 @@ %undef = llvm.mlir.undef : !llvm.struct<(i32)> %false = arith.constant false %state_3 = iterators.createstate(%state, %state_1, %undef, %false, %undef, %false) : !state_type - %0 = call @iterators.constantstream.open.2(%state_3) : (!state_type) -> !state_type + %0 = call @iterators.mergejoin.open.2(%state_3) : (!state_type) -> !state_type %1:2 = scf.while (%arg0 = %0) : (!state_type) -> (!state_type, !llvm.struct<(i32, i32)>) { - %6:3 = func.call @iterators.constantstream.next.2(%arg0) : (!state_type) -> (!state_type, i1, !llvm.struct<(i32, i32)>) + %6:3 = func.call @iterators.mergejoin.next.2(%arg0) : (!state_type) -> (!state_type, i1, !llvm.struct<(i32, i32)>) scf.condition(%6#1) %6#0, %6#2 : !state_type, !llvm.struct<(i32, 
i32)> } do { ^bb0(%arg0: !state_type, %arg1: !llvm.struct<(i32, i32)>): @@ -272,7 +273,7 @@ %12 = llvm.call @printf(%11, %7, %9) : (!llvm.ptr, i64, i64) -> i32 scf.yield %arg0 : !state_type } - %2 = call @iterators.constantstream.close.2(%1#0) : (!state_type) -> !state_type + %2 = call @iterators.mergejoin.close.2(%1#0) : (!state_type) -> !state_type %3 = llvm.mlir.addressof @iterators.frmt_spec.0 : !llvm.ptr %4 = llvm.getelementptr %3[0] : (!llvm.ptr) -> !llvm.ptr, i8 %5 = llvm.call @printf(%4) : (!llvm.ptr) -> i32 From 85cc35fb1c68762e66ec5dba7b578b6d01115792 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 12 Apr 2023 13:26:10 +0000 Subject: [PATCH 06/10] xxx-tablegen --- .../include/iterators/Dialect/Iterators/IR/IteratorsOps.td | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td b/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td index 722d7763ddac..27d246304618 100644 --- a/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td +++ b/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td @@ -263,10 +263,7 @@ def Iterators_MergeJoinOp : Iterators_Op<"mergejoin", // XXX: add type constrain let summary = "Join two sorted streams of tuples on their first element."; let description = [{ }]; - let arguments = (ins - Iterators_StreamOf:$lhs, - Iterators_StreamOf:$rhs - ); + let arguments = (ins Iterators_Stream:$lhs, Iterators_Stream:$rhs); let results = (outs Iterators_StreamOf:$result); let assemblyFormat = [{ $lhs `and` $rhs attr-dict `:` functional-type(operands, results) From 34073678d449fb7c7909f0e5a337bb90d22e55e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 12 Apr 2023 13:26:53 +0000 Subject: [PATCH 07/10] xxx-lowering --- .../IteratorsToLLVM/IteratorAnalysis.cpp | 21 ++ .../IteratorsToLLVM/IteratorsToLLVM.cpp | 235 ++++++++++++++++++ 2 
files changed, 256 insertions(+) diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp index acf620b37e10..71880417331e 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp @@ -87,6 +87,26 @@ StateTypeComputer::operator()(MapOp op, return StateType::get(context, {upstreamStateTypes[0]}); } +/// The state of MergeJoinOp consists of (1) the states of the two upstream ops, +/// (2) the last element successfully consumed from each of the two upstream +/// ops if any, and (3) two Booleans indicating whether these elements exist, +/// respectively. +template <> +StateType +StateTypeComputer::operator()(MergeJoinOp op, + llvm::SmallVector upstreamStateTypes) { + MLIRContext *context = op->getContext(); + StateType lhsStateType = upstreamStateTypes[0]; + StateType rhsStateType = upstreamStateTypes[1]; + auto lhsStreamType = op.getLhs().getType().cast(); + auto rhsStreamType = op.getRhs().getType().cast(); + Type lhsElementType = lhsStreamType.getElementType(); + Type rhsElementType = rhsStreamType.getElementType(); + Type i1 = IntegerType::get(context, /*width=*/1); + return StateType::get(context, {lhsStateType, i1, lhsElementType, + rhsStateType, i1, rhsElementType}); +} + /// The state of ReduceOp only consists of the state of its upstream iterator, /// i.e., the state of the iterator that produces its input stream. 
template <> @@ -183,6 +203,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis( ConstantStreamOp, FilterOp, MapOp, + MergeJoinOp, ReduceOp, TabularViewToStreamOp, ValueToStreamOp, diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp index 0fef8b6694ea..41140284a2f1 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp @@ -864,6 +864,237 @@ static Value buildStateCreation(MapOp op, MapOp::Adaptor adaptor, return b.create(stateType, upstreamState); } +//===----------------------------------------------------------------------===// +// MapOp. +//===----------------------------------------------------------------------===// + +/// XXX +static Value buildOpenBody(MergeJoinOp op, OpBuilder &builder, + Value initialState, + ArrayRef upstreamInfos) { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Open both upstream states. + Value state = initialState; + for (auto i : {0, 1}) { + Type upstreamStateType = upstreamInfos[i].stateType; + IntegerAttr fieldIndex = b.getIndexAttr(i * 3); + + // Extract upstream state. + Value initialUpstreamState = b.create( + upstreamStateType, initialState, fieldIndex); + + // Call Open on upstream. + SymbolRefAttr openFunc = upstreamInfos[i].openFunc; + auto callOp = b.create(openFunc, upstreamStateType, + initialUpstreamState); + + // Update upstream state. 
+ Value updatedUpstreamState = callOp->getResult(0); + state = b.create(initialState, fieldIndex, + updatedUpstreamState); + } + + return state; +} + +/// XXX +static llvm::SmallVector +buildNextBody(MergeJoinOp op, OpBuilder &builder, Value initialState, + ArrayRef upstreamInfos, Type elementType) { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + MLIRContext *context = b.getContext(); + Type i1 = IntegerType::get(context, /*width=*/1); + + // Determine various derived types. + auto stateType = initialState.getType().cast(); + Type lhsStateType = upstreamInfos[0].stateType; + Type rhsStateType = upstreamInfos[1].stateType; + auto lhsStreamType = op.getLhs().getType().cast(); + auto rhsStreamType = op.getRhs().getType().cast(); + Type lhsElementType = lhsStreamType.getElementType(); + Type rhsElementType = rhsStreamType.getElementType(); + + TypeRange stateTypes{lhsStateType, rhsStateType}; + TypeRange elementTypes{lhsElementType, rhsElementType}; + + SymbolRefAttr lhsNextFunc = upstreamInfos[0].nextFunc; + SymbolRefAttr rhsNextFunc = upstreamInfos[1].nextFunc; + SmallVector nextFuncs = {lhsNextFunc, rhsNextFunc}; + + // Fetch initial upstream elements if required. + SmallVector upstreamStates(2); + SmallVector upstreamHasElements(2); + SmallVector upstreamElements(2); + for (auto i : {0, 1}) { + Value initialUpstreamState = b.create( + lhsStateType, initialState, b.getIndexAttr(i * 3)); + Value initialHasElement = b.create( + i1, initialState, b.getIndexAttr(i * 3 + 1)); + auto ifOp = b.create( + /*condition=*/initialHasElement, + /*thenBuilder=*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + // The element stored in the state is valid, so take that and don't + // modify the corresponding upstream state. 
+ Value initialElement = b.create( + elementTypes[i], initialState, b.getIndexAttr(i * 3 + 2)); + b.create(ValueRange{initialUpstreamState, + initialHasElement, initialElement}); + }, /*elseBuilder=*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + // The element stored in the state is not valid, so fetch a new one by + // calling Next and return the result of that. + TypeRange resultTypes{stateTypes[i], i1, elementTypes[i]}; + auto callOp = b.create(nextFuncs[i], resultTypes, + initialUpstreamState); + b.create(callOp->getResults()); + }); + upstreamStates[i] = ifOp->getResult(0); + upstreamHasElements[i] = ifOp->getResult(1); + upstreamElements[i] = ifOp->getResult(2); + } + + // Main while loop looking for a match. + ValueRange whileInputs // (force formatting) + {upstreamStates[0], upstreamHasElements[0], upstreamElements[0], + upstreamStates[1], upstreamHasElements[1], upstreamElements[1]}; + scf::WhileOp whileOp = b.create( + /*resultTypes=*/whileInputs.getTypes(), whileInputs, + /*beforeBuilder=*/ + [&](OpBuilder &builder, Location loc, ValueRange args) { + ImplicitLocOpBuilder b(loc, builder); + Value lhsHasElement = args[1]; + Value rhsHasElement = args[4]; + Value bothSidesHaveElement = + b.create(lhsHasElement, rhsHasElement); + + // XXX: make extendible: + Value lhsElement = args[2]; + Value rhsElement = args[5]; + Value elementsNotEqual = b.create( + arith::CmpIPredicate::ne, lhsElement, rhsElement); + + // If the two elements are valid but not the same, we need to continue + // searching for a match. 
+ Value continueLoop = + b.create(bothSidesHaveElement, elementsNotEqual); + b.create(continueLoop, args); + }, + /*afterBuilder=*/ + [&](OpBuilder &builder, Location loc, ValueRange args) { + ImplicitLocOpBuilder b(loc, builder); + Value lhsElement = args[2]; + Value rhsElement = args[5]; + Value lhsIsSmaller = b.create(arith::CmpIPredicate::slt, + lhsElement, rhsElement); + auto ifOp = b.create( + /*condition=*/lhsIsSmaller, + /*thenBuilder=*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + TypeRange resultTypes{stateTypes[0], i1, elementTypes[0]}; + auto callOp = + b.create(nextFuncs[0], resultTypes, args[0]); + Value updatedLhsState = callOp->getResult(0); + Value updatedLhsHasElement = callOp->getResult(1); + Value updatedLhsElement = callOp->getResult(2); + b.create( + ValueRange{updatedLhsState, updatedLhsHasElement, + updatedLhsElement, args[3], args[4], args[5]}); + }, + /*elseBuilder=*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + TypeRange resultTypes{stateTypes[1], i1, elementTypes[1]}; + auto callOp = + b.create(nextFuncs[1], resultTypes, args[3]); + Value updatedRhsState = callOp->getResult(0); + Value updatedRhsHasElement = callOp->getResult(1); + Value updatedRhsElement = callOp->getResult(2); + b.create( + ValueRange{args[0], args[1], args[2], updatedRhsState, + updatedRhsHasElement, updatedRhsElement}); + }); + + b.create(ifOp.getResults()); + }); + + Value finalLhsState = whileOp->getResult(0); + Value finalRhsState = whileOp->getResult(3); + Value finalLhsHasElement = whileOp->getResult(1); + Value finalRhsHasElement = whileOp->getResult(4); + Value finalLhsElement = whileOp->getResult(2); + Value finalRhsElement = whileOp->getResult(5); + + // Update state and compute return values. 
+ Value constFalse = b.create(/*value=*/0, /*width=*/1); + Value updatedState = b.create( + stateType, ValueRange{finalLhsState, constFalse, finalLhsElement, + finalRhsState, constFalse, finalRhsElement}); + + Value nextElement = b.create( + elementType, ValueRange{finalLhsElement, finalRhsElement}); + Value hasNext = + b.create(finalLhsHasElement, finalRhsHasElement); + + return {updatedState, hasNext, nextElement}; +} + +/// XXX +static Value buildCloseBody(MergeJoinOp op, OpBuilder &builder, + Value initialState, + ArrayRef upstreamInfos) { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Close both upstream states. + Value state = initialState; + for (auto i : {0, 1}) { + Type upstreamStateType = upstreamInfos[i].stateType; + IntegerAttr fieldIndex = b.getIndexAttr(i * 3); + + // Extract upstream state. + Value initialUpstreamState = b.create( + upstreamStateType, initialState, fieldIndex); + + // Call Close on upstream. + SymbolRefAttr closeFunc = upstreamInfos[i].closeFunc; + auto callOp = b.create(closeFunc, upstreamStateType, + initialUpstreamState); + + // Update upstream state. 
+ Value updatedUpstreamState = callOp->getResult(0); + state = b.create(initialState, fieldIndex, + updatedUpstreamState); + } + + return state; +} + +/// XXX +static Value buildStateCreation(MergeJoinOp op, MergeJoinOp::Adaptor adaptor, + OpBuilder &builder, StateType stateType) { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + Value lhsState = adaptor.getLhs(); + Value rhsState = adaptor.getRhs(); + Value constFalse = b.create(/*value=*/0, /*width=*/1); + auto lhsStreamType = op.getLhs().getType().cast(); + auto rhsStreamType = op.getRhs().getType().cast(); + Type lhsElementType = lhsStreamType.getElementType(); + Type rhsElementType = rhsStreamType.getElementType(); + Value lhsUndefElement = b.create(lhsElementType); + Value rhsUndefElement = b.create(rhsElementType); + return b.create( + stateType, ValueRange{lhsState, constFalse, lhsUndefElement, // (force nl) + rhsState, constFalse, rhsUndefElement}); +} + //===----------------------------------------------------------------------===// // ReduceOp. 
//===----------------------------------------------------------------------===// @@ -1546,6 +1777,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder, ConstantStreamOp, FilterOp, MapOp, + MergeJoinOp, ReduceOp, TabularViewToStreamOp, ValueToStreamOp, @@ -1566,6 +1798,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState, ConstantStreamOp, FilterOp, MapOp, + MergeJoinOp, ReduceOp, TabularViewToStreamOp, ValueToStreamOp, @@ -1587,6 +1820,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder, ConstantStreamOp, FilterOp, MapOp, + MergeJoinOp, ReduceOp, TabularViewToStreamOp, ValueToStreamOp, @@ -1606,6 +1840,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder, ConstantStreamOp, FilterOp, MapOp, + MergeJoinOp, ReduceOp, TabularViewToStreamOp, ValueToStreamOp, From 103495d578c8a791832490409524450528eb61e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 12 Apr 2023 13:27:14 +0000 Subject: [PATCH 08/10] xxx-thenbuilder --- .../lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp index 41140284a2f1..ff4d02e7bf03 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp @@ -612,7 +612,7 @@ buildNextBody(FilterOp op, OpBuilder &builder, Value initialState, // If we got an element, apply predicate. auto ifOp = b.create( /*condition=*/hasNext, - /*ifBuilder=*/ + /*thenBuilder=*/ [&](OpBuilder &builder, Location loc) { ImplicitLocOpBuilder b(loc, builder); @@ -783,7 +783,7 @@ buildNextBody(MapOp op, OpBuilder &builder, Value initialState, // If we got an element, apply map function. 
auto ifOp = b.create( /*condition=*/hasNext, - /*ifBuilder=*/ + /*thenBuilder=*/ [&](OpBuilder &builder, Location loc) { // Apply map function. ImplicitLocOpBuilder b(loc, builder); @@ -1186,7 +1186,7 @@ buildNextBody(ReduceOp op, OpBuilder &builder, Value initialState, Value firstHasNext = firstNextCall->getResult(1); auto ifOp = b.create( /*condition=*/firstHasNext, - /*ifBuilder=*/ + /*thenBuilder=*/ [&](OpBuilder &builder, Location loc) { ImplicitLocOpBuilder b(loc, builder); From fe3e61ab2beffbd2db2fef3b5d722d2a53cebbf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 12 Apr 2023 13:27:43 +0000 Subject: [PATCH 09/10] xxx-newline --- .../iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp index ff4d02e7bf03..3263cd60a04e 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp @@ -1681,7 +1681,6 @@ buildNextBody(ZipOp op, OpBuilder &builder, Value initialState, /// into %initialState[0] : !iterators.state static Value buildCloseBody(ZipOp op, OpBuilder &builder, Value initialState, ArrayRef upstreamInfos) { - Location loc = op.getLoc(); ImplicitLocOpBuilder b(loc, builder); From e8e6ceb0c96e56426f059f9e97ad01b001c44bf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 12 Apr 2023 14:56:46 +0000 Subject: [PATCH 10/10] xxx-tablegen --- .../Dialect/Iterators/IR/IteratorsOps.td | 16 ++++++++-- .../lib/Dialect/Iterators/IR/Iterators.cpp | 30 +++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td b/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td index 
27d246304618..9490fc7edb7a 100644 --- a/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td +++ b/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td @@ -258,15 +258,25 @@ def Iterators_MapOp : Iterators_Op<"map", }]; } -def Iterators_MergeJoinOp : Iterators_Op<"mergejoin", // XXX: add type constraint - [DeclareOpInterfaceMethods]> { +def Iterators_MergeJoinOp : Iterators_Op<"mergejoin", [ + PredOpTrait<"the element type of the result stream must be a tuple of the " + "two respective element types of the two input streams", + CPred<[{ + $result.getType().cast().getElementType() == + TupleType::get( + $result.getContext(), + TypeRange{$lhs.getType().cast().getElementType(), + $rhs.getType().cast().getElementType()})}]>>, + DeclareOpInterfaceMethods + ]> { let summary = "Join two sorted streams of tuples on their first element."; let description = [{ }]; let arguments = (ins Iterators_Stream:$lhs, Iterators_Stream:$rhs); let results = (outs Iterators_StreamOf:$result); let assemblyFormat = [{ - $lhs `and` $rhs attr-dict `:` functional-type(operands, results) + $lhs `and` $rhs attr-dict `:` type($result) + custom(type($lhs), type($rhs), ref(type($result))) }]; let extraClassDefinition = [{ /// Implement OpAsmOpInterface. 
diff --git a/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp b/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp
index 87a68d6881fa..d14f60b96bc3 100644
--- a/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp
+++ b/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp
@@ -62,6 +62,36 @@ void IteratorsDialect::initialize() {
 // Iterators operations
 //===----------------------------------------------------------------------===//
 
+/// Custom directive for join ops: derives the types of the two input streams
+/// from the result type, which must be a stream of two-field tuples.
+static ParseResult parseJoinTypes(AsmParser &parser, Type &lhsType,
+                                  Type &rhsType, Type resultType) {
+  auto streamType = resultType.dyn_cast<StreamType>();
+  if (!streamType)
+    return parser.emitError(parser.getNameLoc())
+           << "expected result to be a StreamType";
+
+  auto tupleType = streamType.getElementType().dyn_cast<TupleType>();
+  if (!tupleType)
+    return parser.emitError(parser.getNameLoc())
+           << "expected result to be a stream of TupleType";
+  if (tupleType.size() != 2)
+    return parser.emitError(parser.getNameLoc())
+           << "expected result to be a stream of TupleType with two fields";
+
+  // Field 0 of the tuple gives the lhs element type, field 1 that of rhs.
+  MLIRContext *context = resultType.getContext();
+  lhsType = StreamType::get(context, tupleType.getTypes()[0]);
+  rhsType = StreamType::get(context, tupleType.getTypes()[1]);
+
+  return success();
+}
+
+/// Prints nothing: parseJoinTypes recovers both types from the result type.
+static void printJoinTypes(AsmPrinter & /*printer*/, Operation * /*op*/,
+                           Type /*lhsType*/, Type /*rhsType*/,
+                           Type /*resultType*/) {}
+
 static ParseResult parseInsertValueType(AsmParser & /*parser*/, Type &valueType,
                                         Type stateType, IntegerAttr indexAttr) {
   int64_t index = indexAttr.getValue().getSExtValue();