diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp index 175ff3363..872d3ad1c 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp @@ -137,7 +137,7 @@ struct GeqrfOpLowering : public OpRewritePattern { // `101` for row-major, `102` for col-major auto layout = rewriter.create( op.getLoc(), type_llvm_lapack_int, - rewriter.getIntegerAttr(type_lapack_int, 101)); + rewriter.getIntegerAttr(type_lapack_int, 102)); auto m = rewriter.create( op.getLoc(), type_llvm_lapack_int, rewriter.getIntegerAttr(type_lapack_int, inputShape[0])); @@ -183,7 +183,8 @@ struct GeqrfOpLowering : public OpRewritePattern { SmallVector aliases; for (int i = 0; i < 3; ++i) { - aliases.push_back(stablehlo::OutputOperandAliasAttr::get(ctx, {}, i, {})); + aliases.push_back( + stablehlo::OutputOperandAliasAttr::get(ctx, {i}, i, {})); } auto jit_call_op = rewriter.create( @@ -411,7 +412,7 @@ struct GeqrtOpLowering : public OpRewritePattern { // `101` for row-major, `102` for col-major auto layout = rewriter.create( op.getLoc(), type_llvm_lapack_int, - rewriter.getIntegerAttr(type_lapack_int, 101)); + rewriter.getIntegerAttr(type_lapack_int, 102)); auto m = rewriter.create( op.getLoc(), type_llvm_lapack_int, rewriter.getIntegerAttr(type_lapack_int, inputShape[0])); @@ -472,8 +473,8 @@ struct GeqrtOpLowering : public OpRewritePattern { op.getLoc(), type_T, cast(makeAttr(type_T, 0))); SmallVector isColMajorArr = {true, true, true}; - SmallVector operandRanks = {2, 1, 0}; - SmallVector outputRanks = {2, 1, 0}; + SmallVector operandRanks = {2, 2, 0}; + SmallVector outputRanks = {2, 2, 0}; auto operandLayouts = getSHLOLayout(rewriter, operandRanks, isColMajorArr, 2); auto resultLayouts = getSHLOLayout(rewriter, outputRanks, isColMajorArr, 2); @@ -622,7 +623,7 @@ struct OrgqrOpLowering : public OpRewritePattern { // `101` for row-major, `102` for col-major auto layout = rewriter.create( op.getLoc(), type_llvm_lapack_int, - rewriter.getIntegerAttr(type_lapack_int, 101)); + rewriter.getIntegerAttr(type_lapack_int, 102)); auto mC = inputShape[0]; auto m = rewriter.create( op.getLoc(), type_llvm_lapack_int, @@ -663,7 +664,7 @@ struct OrgqrOpLowering : public OpRewritePattern { auto resultLayouts = getSHLOLayout(rewriter, outputRanks, isColMajorArr, 2); SmallVector aliases; - aliases.push_back(stablehlo::OutputOperandAliasAttr::get(ctx, {0}, 0, {})); + aliases.push_back(stablehlo::OutputOperandAliasAttr::get(ctx, {}, 0, {})); auto jit_call_op = rewriter.create( op.getLoc(), TypeRange{inputType}, @@ -933,7 +934,7 @@ struct OrmqrOpLowering : public OpRewritePattern { // `101` for row-major, `102` for col-major auto layout = rewriter.create( op.getLoc(), type_llvm_lapack_int, - rewriter.getIntegerAttr(type_lapack_int, 101)); + rewriter.getIntegerAttr(type_lapack_int, 102)); auto side = rewriter.create( op.getLoc(), type_llvm_char, @@ -1200,7 +1201,7 @@ struct GemqrtOpLowering : public OpRewritePattern { // `101` for row-major, `102` for col-major auto layout = rewriter.create( op.getLoc(), type_llvm_lapack_int, - rewriter.getIntegerAttr(type_lapack_int, 101)); + rewriter.getIntegerAttr(type_lapack_int, 102)); auto side = rewriter.create( op.getLoc(), type_llvm_char, diff --git a/test/lit_tests/linalg/gemqrt_square.mlir b/test/lit_tests/linalg/gemqrt_square.mlir index b57219111..92e703f8b 100644 --- a/test/lit_tests/linalg/gemqrt_square.mlir +++ b/test/lit_tests/linalg/gemqrt_square.mlir @@ -8,7 +8,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgemqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(76 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(78 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 @@ -30,7 +30,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgemqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(76 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(84 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 @@ -52,7 +52,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgemqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(82 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(78 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(48 : i64) : i64 @@ -74,7 +74,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgemqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(82 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(84 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(48 : i64) : i64 diff --git a/test/lit_tests/linalg/gemqrt_tall.mlir b/test/lit_tests/linalg/gemqrt_tall.mlir index 2a3c6b265..672de8342 100644 --- a/test/lit_tests/linalg/gemqrt_tall.mlir +++ b/test/lit_tests/linalg/gemqrt_tall.mlir @@ -8,7 +8,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgemqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(76 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(78 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 @@ -31,7 +31,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgemqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(76 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(84 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 @@ -54,7 +54,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgemqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(82 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(78 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(48 : i64) : i64 @@ -77,7 +77,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgemqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(82 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(84 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(48 : i64) : i64 diff --git a/test/lit_tests/linalg/geqrf_square.mlir b/test/lit_tests/linalg/geqrf_square.mlir index 2d83db944..af026f461 100644 --- a/test/lit_tests/linalg/geqrf_square.mlir +++ b/test/lit_tests/linalg/geqrf_square.mlir @@ -10,7 +10,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgeqrf_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(64 : i64) : i64 // CPU-NEXT: %2 = llvm.call @enzymexla_lapacke_sgeqrf_(%0, %1, %1, %arg0, %1, %arg1) : (i64, i64, i64, !llvm.ptr, i64, !llvm.ptr) -> i64 // CPU-NEXT: llvm.store %2, %arg2 : i64, !llvm.ptr @@ -20,7 +20,7 @@ module { // CPU-NEXT: func.func @main(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xf32>, tensor) { // CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor // CPU-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<64xf32> -// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrf_[[WRAPPER_ID]] (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64xf32>, tensor) -> (tensor<64x64xf32>, tensor<64xf32>, tensor) +// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrf_[[WRAPPER_ID]] (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64xf32>, tensor) -> (tensor<64x64xf32>, tensor<64xf32>, tensor) // CPU-NEXT: return %0#0, %0#1, %0#2 : tensor<64x64xf32>, tensor<64xf32>, tensor // CPU-NEXT: } diff --git a/test/lit_tests/linalg/geqrf_tall_thin.mlir b/test/lit_tests/linalg/geqrf_tall_thin.mlir index 466cc7506..02b66e740 100644 --- a/test/lit_tests/linalg/geqrf_tall_thin.mlir +++ b/test/lit_tests/linalg/geqrf_tall_thin.mlir @@ -10,7 +10,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgeqrf_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(64 : i64) : i64 // CPU-NEXT: %2 = llvm.mlir.constant(32 : i64) : i64 // CPU-NEXT: %3 = llvm.call @enzymexla_lapacke_sgeqrf_(%0, %1, %2, %arg0, %1, %arg1) : (i64, i64, i64, !llvm.ptr, i64, !llvm.ptr) -> i64 @@ -21,7 +21,7 @@ module { // CPU-NEXT: func.func @main(%arg0: tensor<64x32xf32>) -> (tensor<64x32xf32>, tensor<32xf32>, tensor) { // CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor // CPU-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<32xf32> -// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrf_[[WRAPPER_ID]] (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x32xf32>, tensor<32xf32>, tensor) -> (tensor<64x32xf32>, tensor<32xf32>, tensor) +// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrf_[[WRAPPER_ID]] (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x32xf32>, tensor<32xf32>, tensor) -> (tensor<64x32xf32>, tensor<32xf32>, tensor) // CPU-NEXT: return %0#0, %0#1, %0#2 : tensor<64x32xf32>, tensor<32xf32>, tensor // CPU-NEXT: } diff --git a/test/lit_tests/linalg/geqrf_wide_thin.mlir b/test/lit_tests/linalg/geqrf_wide_thin.mlir index ebb9365d7..7d4a98c39 100644 --- a/test/lit_tests/linalg/geqrf_wide_thin.mlir +++ b/test/lit_tests/linalg/geqrf_wide_thin.mlir @@ -10,7 +10,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgeqrf_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(32 : i64) : i64 // CPU-NEXT: %2 = llvm.mlir.constant(64 : i64) : i64 // CPU-NEXT: %3 = llvm.call @enzymexla_lapacke_sgeqrf_(%0, %1, %2, %arg0, %1, %arg1) : (i64, i64, i64, !llvm.ptr, i64, !llvm.ptr) -> i64 @@ -21,7 +21,7 @@ module { // CPU-NEXT: func.func @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32>, tensor<32xf32>, tensor) { // CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor // CPU-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<32xf32> -// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrf_[[WRAPPER_ID]] (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<32x64xf32>, tensor<32xf32>, tensor) -> (tensor<32x64xf32>, tensor<32xf32>, tensor) +// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrf_[[WRAPPER_ID]] (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<32x64xf32>, tensor<32xf32>, tensor) -> (tensor<32x64xf32>, tensor<32xf32>, tensor) // CPU-NEXT: return %0#0, %0#1, %0#2 : tensor<32x64xf32>, tensor<32xf32>, tensor // CPU-NEXT: } diff --git a/test/lit_tests/linalg/geqrt_square.mlir b/test/lit_tests/linalg/geqrt_square.mlir index f9224e2e0..11981a122 100644 --- a/test/lit_tests/linalg/geqrt_square.mlir +++ b/test/lit_tests/linalg/geqrt_square.mlir @@ -8,7 +8,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgeqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(64 : i64) : i64 // CPU-NEXT: %2 = llvm.call @enzymexla_lapacke_sgeqrt_(%0, %1, %1, %1, %arg0, %1, %arg1, %1) : (i64, i64, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64) -> i64 // CPU-NEXT: llvm.store %2, %arg2 : i64, !llvm.ptr @@ -18,7 +18,7 @@ module { // CPU-NEXT: func.func @main(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32>, tensor) { // CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor // CPU-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32> -// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrt_[[WRAPPER_ID]] (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64x64xf32>, tensor) -> (tensor<64x64xf32>, tensor<64x64xf32>, tensor) +// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrt_[[WRAPPER_ID]] (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64x64xf32>, tensor) -> (tensor<64x64xf32>, tensor<64x64xf32>, tensor) // CPU-NEXT: return %0#0, %0#1, %0#2 : tensor<64x64xf32>, tensor<64x64xf32>, tensor // CPU-NEXT: } diff --git a/test/lit_tests/linalg/geqrt_tall_thin.mlir b/test/lit_tests/linalg/geqrt_tall_thin.mlir index a9cdaf310..130b9c29a 100644 --- a/test/lit_tests/linalg/geqrt_tall_thin.mlir +++ b/test/lit_tests/linalg/geqrt_tall_thin.mlir @@ -8,7 +8,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgeqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(64 : i64) : i64 // CPU-NEXT: %2 = llvm.mlir.constant(32 : i64) : i64 // CPU-NEXT: %3 = llvm.call @enzymexla_lapacke_sgeqrt_(%0, %1, %2, %2, %arg0, %1, %arg1, %2) : (i64, i64, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64) -> i64 @@ -19,7 +19,7 @@ module { // CPU-NEXT: func.func @main(%arg0: tensor<64x32xf32>) -> (tensor<64x32xf32>, tensor<32x32xf32>, tensor) { // CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor // CPU-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<32x32xf32> -// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrt_[[WRAPPER_ID]] (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x32xf32>, tensor<32x32xf32>, tensor) -> (tensor<64x32xf32>, tensor<32x32xf32>, tensor) +// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrt_[[WRAPPER_ID]] (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x32xf32>, tensor<32x32xf32>, tensor) -> (tensor<64x32xf32>, tensor<32x32xf32>, tensor) // CPU-NEXT: return %0#0, %0#1, %0#2 : tensor<64x32xf32>, tensor<32x32xf32>, tensor // CPU-NEXT: } diff --git a/test/lit_tests/linalg/geqrt_wide_thin.mlir b/test/lit_tests/linalg/geqrt_wide_thin.mlir index 796159989..0b64ce7c9 100644 --- a/test/lit_tests/linalg/geqrt_wide_thin.mlir +++ b/test/lit_tests/linalg/geqrt_wide_thin.mlir @@ -8,7 +8,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sgeqrt_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(32 : i64) : i64 // CPU-NEXT: %2 = llvm.mlir.constant(64 : i64) : i64 // CPU-NEXT: %3 = llvm.call @enzymexla_lapacke_sgeqrt_(%0, %1, %2, %1, %arg0, %1, %arg1, %1) : (i64, i64, i64, i64, !llvm.ptr, i64, !llvm.ptr, i64) -> i64 @@ -19,7 +19,7 @@ module { // CPU-NEXT: func.func @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32>, tensor<32x32xf32>, tensor) { // CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor // CPU-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<32x32xf32> -// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrt_4 (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<32x64xf32>, tensor<32x32xf32>, tensor) -> (tensor<32x64xf32>, tensor<32x32xf32>, tensor) +// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sgeqrt_4 (%arg0, %cst, %c) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<32x64xf32>, tensor<32x32xf32>, tensor) -> (tensor<32x64xf32>, tensor<32x32xf32>, tensor) // CPU-NEXT: return %0#0, %0#1, %0#2 : tensor<32x64xf32>, tensor<32x32xf32>, tensor // CPU-NEXT: } diff --git a/test/lit_tests/linalg/orgqr_square.mlir b/test/lit_tests/linalg/orgqr_square.mlir index d9aec0b9b..14dc53343 100644 --- a/test/lit_tests/linalg/orgqr_square.mlir +++ b/test/lit_tests/linalg/orgqr_square.mlir @@ -10,14 +10,14 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sorgqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(64 : i64) : i64 // CPU-NEXT: %2 = llvm.call @enzymexla_lapacke_sorgqr_(%0, %1, %1, %1, %arg0, %1, %arg1) : (i64, i64, i64, i64, !llvm.ptr, i64, !llvm.ptr) -> i64 // CPU-NEXT: llvm.return // CPU-NEXT: } // CPU-NEXT: llvm.func @enzymexla_lapacke_sorgqr_(i64, i64, i64, i64, !llvm.ptr, i64, !llvm.ptr) -> i64 // CPU-NEXT: func.func @main(%arg0: tensor<64x64xf32>, %arg1: tensor<64xf32>) -> tensor<64x64xf32> { -// CPU-NEXT: %0 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sorgqr_[[WRAPPER_ID]] (%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64xf32>) -> tensor<64x64xf32> +// CPU-NEXT: %0 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sorgqr_[[WRAPPER_ID]] (%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64xf32>) -> tensor<64x64xf32> // CPU-NEXT: return %0 : tensor<64x64xf32> // CPU-NEXT: } diff --git a/test/lit_tests/linalg/orgqr_tall.mlir b/test/lit_tests/linalg/orgqr_tall.mlir index 7a2fafeae..b57d1a929 100644 --- a/test/lit_tests/linalg/orgqr_tall.mlir +++ b/test/lit_tests/linalg/orgqr_tall.mlir @@ -10,7 +10,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sorgqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(64 : i64) : i64 // CPU-NEXT: %2 = llvm.mlir.constant(32 : i64) : i64 // CPU-NEXT: %3 = llvm.call @enzymexla_lapacke_sorgqr_(%0, %1, %2, %2, %arg0, %1, %arg1) : (i64, i64, i64, i64, !llvm.ptr, i64, !llvm.ptr) -> i64 @@ -18,7 +18,7 @@ module { // CPU-NEXT: } // CPU-NEXT: llvm.func @enzymexla_lapacke_sorgqr_(i64, i64, i64, i64, !llvm.ptr, i64, !llvm.ptr) -> i64 // CPU-NEXT: func.func @main(%arg0: tensor<64x32xf32>, %arg1: tensor<32xf32>) -> tensor<64x32xf32> { -// CPU-NEXT: %0 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sorgqr_[[WRAPPER_ID]] (%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>], xla_side_effect_free} : (tensor<64x32xf32>, tensor<32xf32>) -> tensor<64x32xf32> +// CPU-NEXT: %0 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sorgqr_[[WRAPPER_ID]] (%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>], xla_side_effect_free} : (tensor<64x32xf32>, tensor<32xf32>) -> tensor<64x32xf32> // CPU-NEXT: return %0 : tensor<64x32xf32> // CPU-NEXT: } diff --git a/test/lit_tests/linalg/orgqr_wide.mlir b/test/lit_tests/linalg/orgqr_wide.mlir index f1229f517..7b794b0d6 100644 --- a/test/lit_tests/linalg/orgqr_wide.mlir +++ b/test/lit_tests/linalg/orgqr_wide.mlir @@ -10,7 +10,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sorgqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(32 : i64) : i64 // CPU-NEXT: %2 = llvm.mlir.constant(64 : i64) : i64 // CPU-NEXT: %3 = llvm.call @enzymexla_lapacke_sorgqr_(%0, %1, %2, %2, %arg0, %1, %arg1) : (i64, i64, i64, i64, !llvm.ptr, i64, !llvm.ptr) -> i64 @@ -18,7 +18,7 @@ module { // CPU-NEXT: } // CPU-NEXT: llvm.func @enzymexla_lapacke_sorgqr_(i64, i64, i64, i64, !llvm.ptr, i64, !llvm.ptr) -> i64 // CPU-NEXT: func.func @main(%arg0: tensor<32x64xf32>, %arg1: tensor<32xf32>) -> tensor<32x64xf32> { -// CPU-NEXT: %0 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sorgqr_[[WRAPPER_ID]] (%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>], xla_side_effect_free} : (tensor<32x64xf32>, tensor<32xf32>) -> tensor<32x64xf32> +// CPU-NEXT: %0 = enzymexla.jit_call @enzymexla_wrapper_lapacke_sorgqr_[[WRAPPER_ID]] (%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>], xla_side_effect_free} : (tensor<32x64xf32>, tensor<32xf32>) -> tensor<32x64xf32> // CPU-NEXT: return %0 : tensor<32x64xf32> // CPU-NEXT: } diff --git a/test/lit_tests/linalg/ormqr_square.mlir b/test/lit_tests/linalg/ormqr_square.mlir index 03ec34187..157e676eb 100644 --- a/test/lit_tests/linalg/ormqr_square.mlir +++ b/test/lit_tests/linalg/ormqr_square.mlir @@ -8,7 +8,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(76 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(78 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 @@ -29,7 +29,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(76 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(84 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 @@ -50,7 +50,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(82 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(78 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 @@ -71,7 +71,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(82 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(84 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 diff --git a/test/lit_tests/linalg/ormqr_tall.mlir b/test/lit_tests/linalg/ormqr_tall.mlir index 2cb6e4c56..d2a00eaab 100644 --- a/test/lit_tests/linalg/ormqr_tall.mlir +++ b/test/lit_tests/linalg/ormqr_tall.mlir @@ -8,7 +8,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(76 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(78 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 @@ -31,7 +31,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(76 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(84 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 @@ -54,7 +54,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(82 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(78 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(48 : i64) : i64 diff --git a/test/lit_tests/linalg/ormqr_wide.mlir b/test/lit_tests/linalg/ormqr_wide.mlir index 736fc6543..d9d9cedd0 100644 --- a/test/lit_tests/linalg/ormqr_wide.mlir +++ b/test/lit_tests/linalg/ormqr_wide.mlir @@ -11,7 +11,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(76 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(78 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(32 : i64) : i64 @@ -35,7 +35,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(76 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(84 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(32 : i64) : i64 @@ -59,7 +59,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(82 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(78 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64 @@ -83,7 +83,7 @@ module { } // CPU: llvm.func @enzymexla_wrapper_lapacke_sormqr_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(101 : i64) : i64 +// CPU-NEXT: %0 = llvm.mlir.constant(102 : i64) : i64 // CPU-NEXT: %1 = llvm.mlir.constant(82 : i8) : i8 // CPU-NEXT: %2 = llvm.mlir.constant(84 : i8) : i8 // CPU-NEXT: %3 = llvm.mlir.constant(64 : i64) : i64