Skip to content

Multidimensional dense workspaces #475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jul 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,10 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
/// Takes any index notation and concretizes unknowns to make it concrete notation
IndexStmt concretize() const;

/// Takes any index notation and concretizes unknowns to make it concrete notation
/// given a Provenance Graph of indexVars
IndexStmt concretizeScheduled(ProvenanceGraph provGraph, std::vector<IndexVar> forallIndexVarList) const;

/// The \code{split} transformation splits (strip-mines) an index
/// variable into two nested index variables, where the size of the
/// inner index variable is constant. The size of the outer index
Expand Down Expand Up @@ -681,6 +685,12 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
/// reorder computations to increase locality
IndexStmt precompute(IndexExpr expr, IndexVar i, IndexVar iw, TensorVar workspace) const;

/// The precompute transformation is described in kjolstad2019; it allows us
/// to leverage scratchpad memories and reorder computations to increase
/// locality.
IndexStmt precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
std::vector<IndexVar> iw_vars, TensorVar workspace) const;

/// bound specifies a compile-time constraint on an index variable's
/// iteration space that allows knowledge of the
/// size or structured sparsity pattern of the inputs to be
Expand Down Expand Up @@ -1119,6 +1129,10 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr);
/// notation is printed to.
bool isReductionNotation(IndexStmt, std::string* reason=nullptr);

/// Check whether the statement is in the reduction index notation dialect
/// given a schedule described by the Provenance Graph
bool isReductionNotationScheduled(IndexStmt, ProvenanceGraph, std::string* reason=nullptr);

/// Check whether the statement is in the concrete index notation dialect.
/// This means every index variable has a forall node, there are no reduction
/// nodes, and that every reduction variable use is nested inside a compound
Expand All @@ -1136,6 +1150,18 @@ IndexStmt makeReductionNotation(IndexStmt);
/// as needed.
IndexStmt makeConcreteNotation(IndexStmt);


/// Convert einsum notation to reduction notation by applying Einstein's
/// summation convention to sum non-free/reduction variables over their term,
/// taking into account a schedule given by the Provenance Graph.
Assignment makeReductionNotationScheduled(Assignment, ProvenanceGraph);
IndexStmt makeReductionNotationScheduled(IndexStmt, ProvenanceGraph);

/// Convert reduction notation to concrete notation by inserting forall nodes,
/// replacing reduction nodes with compound assignments, and inserting temporaries
/// as needed, while taking into account a schedule given by the Provenance Graph.
IndexStmt makeConcreteNotationScheduled(IndexStmt, ProvenanceGraph, std::vector<IndexVar> forallIndexVars);

/// Returns the results of the index statement, in the order they appear.
std::vector<TensorVar> getResults(IndexStmt stmt);

Expand Down
8 changes: 5 additions & 3 deletions include/taco/index_notation/transformations.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ class Precompute : public TransformationInterface {
public:
Precompute();
Precompute(IndexExpr expr, IndexVar i, IndexVar iw, TensorVar workspace);

Precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
std::vector<IndexVar> iw_vars, TensorVar workspace);

IndexExpr getExpr() const;
IndexVar geti() const;
IndexVar getiw() const;
std::vector<IndexVar>& getIVars() const;
std::vector<IndexVar>& getIWVars() const;
TensorVar getWorkspace() const;

/// Apply the precompute optimization to a concrete index statement.
Expand Down
23 changes: 23 additions & 0 deletions include/taco/ir/workspace_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef TACO_WORKSPACE_REWRITER_H
#define TACO_WORKSPACE_REWRITER_H

#include <vector>
#include <map>


namespace taco {
// Forward declaration: full definition lives in the index-notation headers.
class TensorVar;

namespace ir {
// Forward declarations of the lowered-IR node handles used in the API below.
class Stmt;
class Expr;
}  // namespace ir

/// Rewrite a post-lowered IR statement to take into account multidimensional temporaries.
/// Replaces Dimension GetProperty nodes that correspond to temporary workspaces with
/// their corresponding dimension found in the temporarySizeMap.
///
/// \param stmt             The lowered IR statement to rewrite.
/// \param whereTemps       Temporary workspace tensors whose Dimension GetProperty
///                         nodes should be replaced (presumably the temporaries
///                         introduced by where/precompute — confirm with callers).
/// \param temporarySizeMap Maps each temporary TensorVar to one size expression
///                         per mode; these expressions are substituted for the
///                         matching Dimension GetProperty nodes.
/// \return A new IR statement with the replacements applied.
ir::Stmt rewriteTemporaryGP(const ir::Stmt& stmt, std::vector<TensorVar> whereTemps,
                            std::map<TensorVar, std::vector<ir::Expr>> temporarySizeMap);

}  // namespace taco
#endif  // TACO_WORKSPACE_REWRITER_H
6 changes: 6 additions & 0 deletions include/taco/lower/lowerer_impl_imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,12 @@ class LowererImplImperative : public LowererImpl {
std::vector<TensorVar> whereTemps;
std::map<TensorVar, const AccessNode *> whereTempsToResult;

// Map temporary tensorVars to a list of size expressions for each mode
std::map<TensorVar, std::vector<ir::Expr>> temporarySizeMap;

// List that contains all temporary tensorVars
std::vector<TensorVar> temporaries;

bool captureNextLocatePos = false;
ir::Stmt capturedLocatePos; // used for whereConsumer when want to replicate same locating

Expand Down
2 changes: 2 additions & 0 deletions include/taco/parser/schedule_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace parser {
// [ [ "reorder", "i", "j" ], [ "precompute", "D(i,j)*E(j,k)", "j", "j_pre" ] ]
std::vector<std::vector<std::string>> ScheduleParser(const std::string);

std::vector<std::string> varListParser(const std::string);

// serialize the result of a parse (for debugging)
std::string serializeParsedSchedule(std::vector<std::vector<std::string>>);

Expand Down
8 changes: 6 additions & 2 deletions src/error/error_checks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultV
for (size_t mode = 0; mode < resultVars.size(); mode++) {
IndexVar var = resultVars[mode];
auto dimension = shape.getDimension(mode);
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension &&
!(indexVarDims.at(var).isIndexVarSized() && indexVarDims.at(var).getIndexVarSize() == var) &&
!(dimension.isIndexVarSized() && dimension.getIndexVarSize() == var)) {
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
} else {
indexVarDims.insert({var, dimension});
Expand All @@ -63,7 +65,9 @@ std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultV
dimension = Dimension(a.getIndexSet(mode).size());
}

if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension &&
!(indexVarDims.at(var).isIndexVarSized() && indexVarDims.at(var).getIndexVarSize() == var) &&
!(dimension.isIndexVarSized() && dimension.getIndexVarSize() == var)) {
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
} else {
indexVarDims.insert({var, dimension});
Expand Down
Loading