diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index 5f8693c2f..14620e4ca 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -383,67 +383,76 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const { Forall foralli(node); std::vector i_vars = precompute.getIVars(); - vector forallIndexVars; + bool containsWhere = false; match(foralli, - function([&](const ForallNode* op) { - forallIndexVars.push_back(op->indexVar); + function([&](const WhereNode* op) { + containsWhere = true; }) ); - IndexStmt s = foralli.getStmt(); - TensorVar ws = precompute.getWorkspace(); - IndexExpr e = precompute.getExpr(); - std::vector iw_vars = precompute.getIWVars(); + if (!containsWhere) { + vector forallIndexVars; + match(foralli, + function([&](const ForallNode* op) { + forallIndexVars.push_back(op->indexVar); + }) + ); - map substitutions; - taco_iassert(i_vars.size() == iw_vars.size()) << "i_vars and iw_vars lists must be the same size"; + IndexStmt s = foralli.getStmt(); + TensorVar ws = precompute.getWorkspace(); + IndexExpr e = precompute.getExpr(); + std::vector iw_vars = precompute.getIWVars(); - for (int index = 0; index < (int)i_vars.size(); index++) { - substitutions[i_vars[index]] = iw_vars[index]; - } + map substitutions; + taco_iassert(i_vars.size() == iw_vars.size()) << "i_vars and iw_vars lists must be the same size"; - // Build consumer by replacing with temporary (in replacedStmt) - IndexStmt replacedStmt = replace(s, {{e, ws(i_vars) }}); - if (replacedStmt != s) { - // Then modify the replacedStmt to have the correct foralls - // by concretizing the consumer assignment + for (int index = 0; index < (int)i_vars.size(); index++) { + substitutions[i_vars[index]] = iw_vars[index]; + } - auto consumerAssignment = getConsumerAssignment(replacedStmt, ws); - auto consumerIndexVars = consumerAssignment.getIndexVars(); + // Build consumer by replacing with temporary 
(in replacedStmt) + IndexStmt replacedStmt = replace(s, {{e, ws(i_vars) }}); + if (replacedStmt != s) { + // Then modify the replacedStmt to have the correct foralls + // by concretizing the consumer assignment - auto producerAssignment = getProducerAssignment(ws, i_vars, iw_vars, e, substitutions); - auto producerIndexVars = producerAssignment.getIndexVars(); + auto consumerAssignment = getConsumerAssignment(replacedStmt, ws); + auto consumerIndexVars = consumerAssignment.getIndexVars(); - vector producerForallIndexVars; - vector consumerForallIndexVars; - vector outerForallIndexVars; + auto producerAssignment = getProducerAssignment(ws, i_vars, iw_vars, e, substitutions); + auto producerIndexVars = producerAssignment.getIndexVars(); - bool stopForallDistribution = false; - for (auto &i : util::reverse(forallIndexVars)) { - if (!stopForallDistribution && containsIndexVarScheduled(i_vars, i)) { - producerForallIndexVars.push_back(substitutions[i]); - consumerForallIndexVars.push_back(i); - } else { - auto consumerContains = containsIndexVarScheduled(consumerIndexVars, i); - auto producerContains = containsIndexVarScheduled(producerIndexVars, i); - if (stopForallDistribution || (producerContains && consumerContains)) { - outerForallIndexVars.push_back(i); - stopForallDistribution = true; - } else if (!stopForallDistribution && consumerContains) { + vector producerForallIndexVars; + vector consumerForallIndexVars; + vector outerForallIndexVars; + + bool stopForallDistribution = false; + for (auto &i : util::reverse(forallIndexVars)) { + if (!stopForallDistribution && containsIndexVarScheduled(i_vars, i)) { + producerForallIndexVars.push_back(substitutions[i]); consumerForallIndexVars.push_back(i); - } else if (!stopForallDistribution && producerContains) { - producerForallIndexVars.push_back(i); + } else { + auto consumerContains = containsIndexVarScheduled(consumerIndexVars, i); + auto producerContains = containsIndexVarScheduled(producerIndexVars, i); + if 
(stopForallDistribution || (producerContains && consumerContains)) { + outerForallIndexVars.push_back(i); + stopForallDistribution = true; + } else if (!stopForallDistribution && consumerContains) { + consumerForallIndexVars.push_back(i); + } else if (!stopForallDistribution && producerContains) { + producerForallIndexVars.push_back(i); + } } } - } - IndexStmt consumer = generateForalls(consumerAssignment, consumerForallIndexVars); + IndexStmt consumer = generateForalls(consumerAssignment, consumerForallIndexVars); - IndexStmt producer = generateForalls(producerAssignment, producerForallIndexVars); - Where where(consumer, producer); + IndexStmt producer = generateForalls(producerAssignment, producerForallIndexVars); + Where where(consumer, producer); - stmt = generateForalls(where, outerForallIndexVars); - return; + stmt = generateForalls(where, outerForallIndexVars); + return; + } } IndexNotationRewriter::visit(node); } diff --git a/test/tests-workspaces.cpp b/test/tests-workspaces.cpp index ffc8e163c..80fd5f3f5 100644 --- a/test/tests-workspaces.cpp +++ b/test/tests-workspaces.cpp @@ -45,7 +45,7 @@ TEST(workspaces, tile_vecElemMul_NoTail) { expected.compile(); expected.assemble(); expected.compute(); - ASSERT_TENSOR_EQ(A, expected); + ASSERT_TENSOR_EQ(expected, A); } TEST(workspaces, tile_vecElemMul_Tail1) { @@ -83,7 +83,7 @@ TEST(workspaces, tile_vecElemMul_Tail1) { expected.compile(); expected.assemble(); expected.compute(); - ASSERT_TENSOR_EQ(A, expected); + ASSERT_TENSOR_EQ(expected, A); } TEST(workspaces, tile_vecElemMul_Tail2) { @@ -121,7 +121,7 @@ TEST(workspaces, tile_vecElemMul_Tail2) { expected.compile(); expected.assemble(); expected.compute(); - ASSERT_TENSOR_EQ(A, expected); + ASSERT_TENSOR_EQ(expected, A); // ir::IRPrinter irp = ir::IRPrinter(cout); // @@ -171,7 +171,7 @@ TEST(workspaces, tile_denseMatMul) { expected.compile(); expected.assemble(); expected.compute(); - ASSERT_TENSOR_EQ(A, expected); + ASSERT_TENSOR_EQ(expected, A); // ir::IRPrinter 
irp = ir::IRPrinter(cout); // @@ -218,7 +218,7 @@ TEST(workspaces, precompute2D_add) { expected.compile(); expected.assemble(); expected.compute(); - ASSERT_TENSOR_EQ(A, expected); + ASSERT_TENSOR_EQ(expected, A); } @@ -263,7 +263,7 @@ TEST(workspaces, precompute4D_add) { expected.compile(); expected.assemble(); expected.compute(); - ASSERT_TENSOR_EQ(A, expected); + ASSERT_TENSOR_EQ(expected, A); } TEST(workspaces, precompute4D_multireduce) { @@ -305,7 +305,7 @@ TEST(workspaces, precompute4D_multireduce) { expected.compile(); expected.assemble(); expected.compute(); - ASSERT_TENSOR_EQ(A, expected); + ASSERT_TENSOR_EQ(expected, A); } TEST(workspaces, precompute3D_TspV) { @@ -344,7 +344,7 @@ TEST(workspaces, precompute3D_TspV) { expected.compile(); expected.assemble(); expected.compute(); - ASSERT_TENSOR_EQ(A, expected); + ASSERT_TENSOR_EQ(expected, A); } @@ -388,7 +388,7 @@ TEST(workspaces, precompute3D_multipleWS) { expected.compile(); expected.assemble(); expected.compute(); - ASSERT_TENSOR_EQ(A, expected); + ASSERT_TENSOR_EQ(expected, A); } @@ -431,6 +431,178 @@ TEST(workspaces, precompute3D_renamedIVars_TspV) { expected.compile(); expected.assemble(); expected.compute(); - ASSERT_TENSOR_EQ(A, expected); + ASSERT_TENSOR_EQ(expected, A); } + +TEST(workspaces, DISABLED_tile_dotProduct_1) { + // FIXME: Disabled because currently the precompute algorithm does not appropriately + // find the correct forall substmt to nest the WhereNode in after i has been + // split into i0 and i1. 
As an example, the first precompute below is incorrect + // since it should transform + // forall(i0, forall(i1, A() += B(i) * C(i))) --> + // forall(i0, where(forall(i1, A() += ws(i1)), forall(i1, ws(i1) += B(i) * C(i)))) + // + // But currently the algorithm does + // forall(i0, forall(i1, A() += B(i) * C(i))) --> + // where(forall(i1, A() += ws(i1)), forall(i0, forall(i1, ws(i1) += B(i) * C(i)))) + + int N = 1024; + Tensor A("A"); + Tensor B("B", {N}, Format({Dense})); + Tensor C("C", {N}, Format({Dense})); + + for (int i = 0; i < N; i++) { + B.insert({i}, (double) i); + C.insert({i}, (double) i); + } + + B.pack(); + C.pack(); + + IndexVar i("i"); + IndexVar i_bounded("i_bounded"); + IndexVar i0("i0"), i1("i1"); + IndexExpr BExpr = B(i); + IndexExpr CExpr = C(i); + IndexExpr precomputedExpr = (BExpr) * (CExpr); + A() = precomputedExpr; + + IndexStmt stmt = A.getAssignment().concretize(); + TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense); + TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense); + TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense); + + stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact) + .split(i_bounded, i0, i1, 32); + stmt = stmt.precompute(precomputedExpr, i1, i1, precomputed); + stmt = stmt.precompute(BExpr, i1, i1, B_new) + .precompute(CExpr, i1, i1, C_new); + + stmt = stmt.concretize(); + + A.compile(stmt); + A.assemble(); + A.compute(); + + ir::IRPrinter irp = ir::IRPrinter(cout); + + cout << stmt << endl; + + std::shared_ptr codegen = ir::CodeGen::init_default(cout, ir::CodeGen::ImplementationGen); + ir::Stmt compute = lower(stmt, "compute", false, true); + + irp.print(compute); + cout << endl; + codegen->compile(compute, false); + + Tensor expected("expected"); + expected() = B(i) * C(i); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + +TEST(workspaces, DISABLED_tile_dotProduct_2) { + // FIXME: This is also 
currently disabled since split(...) scheduling commands + // only split on the FIRST INSTANCE of an indexVar (assumes only one). + // This is wrong if the indexVar is not renamed across iw_vars since an indexVar can + // then occur on BOTH the consumer and producer side and should be split across both. + + int N = 1024; + Tensor A("A"); + Tensor B("B", {N}, Format({Dense})); + Tensor C("C", {N}, Format({Dense})); + + for (int i = 0; i < N; i++) { + B.insert({i}, (double) i); + C.insert({i}, (double) i); + } + + B.pack(); + C.pack(); + + IndexVar i("i"); + IndexVar i_bounded("i_bounded"); + IndexVar i0("i0"), i1("i1"); + IndexExpr BExpr = B(i); + IndexExpr CExpr = C(i); + IndexExpr precomputedExpr = (BExpr) * (CExpr); + A() = precomputedExpr; + + IndexStmt stmt = A.getAssignment().concretize(); + TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense); + TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense); + TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense); + + stmt = stmt.precompute(precomputedExpr, i, i, precomputed); + + stmt = stmt.precompute(BExpr, i, i, B_new) + .precompute(CExpr, i, i, C_new); + + stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact) + .split(i_bounded, i0, i1, 32); + + stmt = stmt.concretize(); + + A.compile(stmt); + A.assemble(); + A.compute(); + + Tensor expected("expected"); + expected() = B(i) * C(i); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + +TEST(workspaces, tile_dotProduct_3) { + int N = 1024; + Tensor A("A"); + Tensor B("B", {N}, Format({Dense})); + Tensor C("C", {N}, Format({Dense})); + + for (int i = 0; i < N; i++) { + B.insert({i}, (double) i); + C.insert({i}, (double) i); + } + + B.pack(); + C.pack(); + + IndexVar i("i"); + IndexVar i_bounded("i_bounded"); + IndexVar i0("i0"), i1("i1"); + IndexExpr BExpr = B(i); + IndexExpr CExpr = C(i); + IndexExpr precomputedExpr = (BExpr) * (CExpr); + A() = 
precomputedExpr; + + IndexStmt stmt = A.getAssignment().concretize(); + TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense); + TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense); + TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense); + + stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact) + .split(i_bounded, i0, i1, 32); + stmt = stmt.precompute(precomputedExpr, i0, i0, precomputed); + + stmt = stmt.precompute(BExpr, i1, i1, B_new) + .precompute(CExpr, i1, i1, C_new); + + + stmt = stmt.concretize(); + + A.compile(stmt); + A.assemble(); + A.compute(); + + Tensor expected("expected"); + expected() = B(i) * C(i); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +}