Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bound #521

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open

Bound #521

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d4d492b
add API for split with flag
manya-bansal Oct 29, 2021
b765721
add test for splitUpDown and bound attributes for IndexVars
manya-bansal Nov 30, 2021
cbc872c
start substituting bound rel function calls
manya-bansal Jan 9, 2022
897bb04
add index notation falg
manya-bansal Jan 22, 2022
d6c5bb7
don't add such that node for bound relation
manya-bansal Jan 22, 2022
f2e5cfb
change derive iter
manya-bansal Jan 26, 2022
aa1a114
bound rel node deleted, change workspace tests to reflect changes
manya-bansal Jan 26, 2022
fc86897
api change + new tests
manya-bansal Feb 5, 2022
e74bf06
bounds test file
manya-bansal Feb 13, 2022
9ae9908
bound and rebound test
manya-bansal Feb 13, 2022
ba95326
adding tests for bound
manya-bansal Feb 13, 2022
2cf16af
add more tests
manya-bansal Feb 14, 2022
6dc9ad4
print prov graph
manya-bansal Feb 26, 2022
9ec3e6e
print prov graph
manya-bansal Feb 27, 2022
79f2d47
added an additional test
manya-bansal Feb 27, 2022
ba621eb
merge conflicts
manya-bansal Apr 10, 2022
12af135
change taco-cli-tests.bats to refkect new bound api
manya-bansal Apr 10, 2022
27288ff
ataco-cli-test passing
manya-bansal Apr 12, 2022
a52bb4e
remove assert
manya-bansal Apr 17, 2022
57dd389
check literal split
manya-bansal Apr 17, 2022
c27ef8a
check biunds against literal values
manya-bansal Apr 17, 2022
e21812b
change ir simplify logic + add additional ir tests
manya-bansal Apr 24, 2022
e258dca
remove split up and down
manya-bansal May 30, 2022
af4ff8f
change bound abstraction from indexVar to indexStmt
manya-bansal Jun 6, 2022
b437560
merge upstream
manya-bansal Jun 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,8 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
/// Preconditions:
/// The precondition for bound is that the computation bounds supplied are
/// correct given the inputs that this code will be run on.
IndexStmt bound(IndexVar i, IndexVar i1, size_t bound, BoundType bound_type) const;
// IndexStmt bound(IndexVar i, IndexVar i1, size_t bound, BoundType bound_type) const;
IndexStmt bound(IndexVar i, size_t bound, BoundType bound_type) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can probably remove the old bound call and replace all users with the new method (I don't think we care that much about backwards compat, and bound doesn't have that many users).


/// The unroll primitive unrolls the corresponding loop by a statically-known
/// integer number of iterations
Expand Down Expand Up @@ -1049,7 +1050,7 @@ class IndexVar : public IndexExpr, public IndexVarInterface {
/// Returns the name of the index variable.
std::string getName() const;

// Need these to overshadow the comparisons in for the IndexExpr instrusive pointer
/// Need these to overshadow the comparisons in for the IndexExpr instrusive pointer
friend bool operator==(const IndexVar&, const IndexVar&);
friend bool operator<(const IndexVar&, const IndexVar&);
friend bool operator!=(const IndexVar&, const IndexVar&);
Expand Down Expand Up @@ -1098,16 +1099,17 @@ class SuchThat : public IndexStmt {
public:
SuchThat() = default;
SuchThat(const SuchThatNode*);
SuchThat(IndexStmt stmt, std::vector<IndexVarRel> predicate);
SuchThat(IndexStmt stmt, std::vector<IndexVarRel> predicate, std::map<IndexVar, std::pair<size_t, BoundType>> boundsMap);

IndexStmt getStmt() const;
std::vector<IndexVarRel> getPredicate() const;
std::map<IndexVar, std::pair<size_t, BoundType>> getBounds() const;

typedef SuchThatNode Node;
};

/// Create a suchthat index statement.
SuchThat suchthat(IndexStmt stmt, std::vector<IndexVarRel> predicate);
SuchThat suchthat(IndexStmt stmt, std::vector<IndexVarRel> predicate, std::map<IndexVar, std::pair<size_t, BoundType>> boundsMap);

/// A tensor variable in an index expression, which can either be an operand
/// or the result of the expression.
Expand Down
4 changes: 3 additions & 1 deletion include/taco/index_notation/index_notation_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,16 @@ struct MultiNode : public IndexStmtNode {
};

struct SuchThatNode : public IndexStmtNode {
SuchThatNode(IndexStmt stmt, std::vector<IndexVarRel> predicate) : stmt(stmt), predicate(predicate) {}
SuchThatNode(IndexStmt stmt, std::vector<IndexVarRel> predicate, std::map<IndexVar, std::pair<size_t, BoundType>> boundsMap) : \
stmt(stmt), predicate(predicate), boundsMap(boundsMap) {}

void accept(IndexStmtVisitorStrict* v) const {
v->visit(this);
}

IndexStmt stmt;
std::vector<IndexVarRel> predicate;
std::map<IndexVar, std::pair<size_t, BoundType>> boundsMap;
};

struct SequenceNode : public IndexStmtNode {
Expand Down
49 changes: 15 additions & 34 deletions include/taco/index_notation/provenance_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace taco {
struct IndexVarRelNode;
enum IndexVarRelType {UNDEFINED, SPLIT, DIVIDE, POS, FUSE, BOUND, PRECOMPUTE};
enum IndexVarRelType {UNDEFINED, SPLIT, DIVIDE, POS, FUSE, PRECOMPUTE};

/// A pointer class for IndexVarRelNodes provides some operations for all IndexVarRelTypes
class IndexVarRel : public util::IntrusivePtr<const IndexVarRelNode> {
Expand Down Expand Up @@ -244,39 +244,6 @@ struct FuseRelNode : public IndexVarRelNode {

bool operator==(const FuseRelNode&, const FuseRelNode&);

/// The bound relation allows expressing a constraint or value known at compile-time that allows for compile-time optimizations
struct BoundRelNode : public IndexVarRelNode {
BoundRelNode(IndexVar parentVar, IndexVar boundVar, size_t bound, BoundType boundType);

const IndexVar& getParentVar() const;
const IndexVar& getBoundVar() const;
const size_t& getBound() const;
const BoundType& getBoundType() const;

void print(std::ostream& stream) const;
bool equals(const BoundRelNode &rel) const;
std::vector<IndexVar> getParents() const; // parentVar
std::vector<IndexVar> getChildren() const; // boundVar
std::vector<IndexVar> getIrregulars() const; // boundVar

/// Coordinate bounds remain unchanged, only iteration bounds change
std::vector<ir::Expr> computeRelativeBound(std::set<IndexVar> definedVars, std::map<IndexVar, std::vector<ir::Expr>> computedBounds, std::map<IndexVar, ir::Expr> variableExprs, Iterators iterators, ProvenanceGraph provGraph) const;

/// Constrained depending on bound_type
std::vector<ir::Expr> deriveIterBounds(IndexVar indexVar, std::map<IndexVar, std::vector<ir::Expr>> parentIterBounds, std::map<IndexVar, std::vector<ir::Expr>> parentCoordBounds, std::map<taco::IndexVar, taco::ir::Expr> variableNames, Iterators iterators, ProvenanceGraph provGraph) const;

/// parentVar = boundVar
ir::Expr recoverVariable(IndexVar indexVar, std::map<IndexVar, ir::Expr> variableNames, Iterators iterators, std::map<IndexVar, std::vector<ir::Expr>> parentIterBounds, std::map<IndexVar, std::vector<ir::Expr>> parentCoordBounds, ProvenanceGraph provGraph) const;

/// boundVar = parentVar
ir::Stmt recoverChild(IndexVar indexVar, std::map<IndexVar, ir::Expr> relVariables, bool emitVarDecl, Iterators iterators, ProvenanceGraph provGraph) const;
private:
struct Content;
std::shared_ptr<Content> content;
};

bool operator==(const BoundRelNode&, const BoundRelNode&);

/// The precompute relation allows creating a new precomputeVar that is iterated over for the precompute loop and shares same sizes as parentVar
/// This allows precomputeVar to be scheduled separately from the parentVar
struct PrecomputeRelNode : public IndexVarRelNode {
Expand Down Expand Up @@ -310,6 +277,7 @@ class ProvenanceGraph {
public:
ProvenanceGraph() {}
ProvenanceGraph(IndexStmt concreteStmt);


/// Returns the children of a given index variable, {} if no children or if indexVar is not in graph
std::vector<IndexVar> getChildren(IndexVar indexVar) const;
Expand Down Expand Up @@ -387,6 +355,15 @@ class ProvenanceGraph {
/// does the index variable have a descendant in position space
bool hasPosDescendant(IndexVar indexVar) const;

// /does the index variable have a bound
bool hasBound(IndexVar indexVar) const;

/// get the indexVar's bound
size_t getBound(IndexVar indexVar) const;

/// get the indexVar's boundType
taco::BoundType getBoundType(IndexVar indexVar) const;

/// does the index variable have an exact bound known at compile-time
bool hasExactBound(IndexVar indexVar) const;

Expand All @@ -411,13 +388,17 @@ class ProvenanceGraph {
/// a `.divide` scheduling operation.
bool isDivided(IndexVar indexVar) const;



private:
std::map<IndexVar, IndexVarRel> childRelMap;
std::map<IndexVar, IndexVarRel> parentRelMap;

std::map<IndexVar, std::vector<IndexVar>> parentsMap;
std::map<IndexVar, std::vector<IndexVar>> childrenMap;

std::map<IndexVar, std::pair<size_t, BoundType>> boundsMap;

std::set<IndexVar> nodes;
};

Expand Down
24 changes: 24 additions & 0 deletions include/taco/index_notation/transformations.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Reorder;
class Precompute;
class ForAllReplace;
class AddSuchThatPredicates;
class AddSuchThatBoundMap;
class Parallelize;
class TopoReorder;
class SetAssembleStrategy;
Expand All @@ -36,6 +37,7 @@ class Transformation {
Transformation(Parallelize);
Transformation(TopoReorder);
Transformation(AddSuchThatPredicates);
Transformation(AddSuchThatBoundMap);
Transformation(SetAssembleStrategy);
Transformation(SetMergeStrategy);

Expand Down Expand Up @@ -161,6 +163,28 @@ class AddSuchThatPredicates : public TransformationInterface {
std::ostream& operator<<(std::ostream&, const AddSuchThatPredicates&);


/// Adds a SuchThat node if it does not exist and adds the given BoundsList
class AddSuchThatBoundMap : public TransformationInterface {
public:
AddSuchThatBoundMap();

AddSuchThatBoundMap(std::map<IndexVar, std::pair<size_t, BoundType>> boundsMap);

std::map<IndexVar, std::pair<size_t, BoundType>> getBoundsMap() const;

IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;

void print(std::ostream &os) const;

private:
struct Content;
std::shared_ptr<Content> content;
};

std::ostream& operator<<(std::ostream&, const AddSuchThatBoundMap&);



/// The parallelize optimization tags a Forall as parallelized
/// after checking for preconditions
class Parallelize : public TransformationInterface {
Expand Down
Loading