Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

ScheduleTree: replace comparison operators with named functions #583

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tc/core/polyhedral/schedule_isl_conversion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ std::unique_ptr<ScheduleTree> fromIslSchedule(isl::schedule schedule) {
// Note that the children of set and sequence nodes are always filters, so
// they cannot be replaced by empty trees.
bool validateSchedule(const ScheduleTree* st) {
return *st == *fromIslSchedule(toIslSchedule(st));
return st->treeEquals(fromIslSchedule(toIslSchedule(st)).get());
}

bool validateSchedule(isl::schedule sc) {
Expand Down
16 changes: 6 additions & 10 deletions tc/core/polyhedral/schedule_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,21 +336,17 @@ vector<const ScheduleTree*> ScheduleTree::collectDFSPreorder(
return functional::Filter(filterType, collectDFSPreorder(tree));
}

bool ScheduleTree::operator==(const ScheduleTree& other) const {
// ctx_ cmp ?
if (type_ != other.type_) {
bool ScheduleTree::treeEquals(const ScheduleTree* other) const {
if (!nodeEquals(other)) {
return false;
}
if (children_.size() != other.children_.size()) {
if (numChildren() != other->numChildren()) {
return false;
}
if (!elemEquals(this, &other, type_)) {
return false;
}
TC_CHECK(!other.as<ScheduleTreeSet>())
TC_CHECK(!other->as<ScheduleTreeSet>())
<< "NYI: ScheduleTreeType::Set comparison";
for (size_t i = 0; i < children_.size(); ++i) {
if (*children_[i] != *other.children_[i]) {
for (size_t i = 0, e = numChildren(); i < e; ++i) {
if (!child({i})->treeEquals(other->child({i}))) {
return false;
}
}
Expand Down
14 changes: 9 additions & 5 deletions tc/core/polyhedral/schedule_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,6 @@ struct ScheduleTree {
public:
virtual ~ScheduleTree();

bool operator==(const ScheduleTree& other) const;
bool operator!=(const ScheduleTree& other) const {
return !(*this == other);
}

// Swap a tree with with the given tree.
void swapChild(size_t pos, ScheduleTreeUPtr& swappee) {
TC_CHECK_GE(pos, 0u) << "position out of children bounds";
Expand Down Expand Up @@ -469,6 +464,15 @@ struct ScheduleTree {
// Note that this function does _not_ clone the child trees.
virtual ScheduleTreeUPtr clone() const = 0;

// Compare the current node to the "other" node.
// Note that this function does _not_ compare the child trees,
// use treeEquals() instead to compare entire trees.
virtual bool nodeEquals(const ScheduleTree* other) const = 0;

// Comapre the subtree rooted at the current node to the subtree
// rooted at "other".
bool treeEquals(const ScheduleTree* other) const;

//
// Data members
//
Expand Down
101 changes: 46 additions & 55 deletions tc/core/polyhedral/schedule_tree_elem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,21 +281,26 @@ ScheduleTreeThreadSpecificMarker::make(
return res;
}

bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const {
if (permutable_ != other.permutable_) {
bool ScheduleTreeBand::nodeEquals(const ScheduleTreeBand* otherBand) const {
if (!otherBand) {
return false;
}
if (coincident_.size() != other.coincident_.size()) {
if (permutable_ != otherBand->permutable_) {
return false;
}
if (unroll_.size() != other.unroll_.size()) {
if (coincident_.size() != otherBand->coincident_.size()) {
return false;
}
if (unroll_.size() != otherBand->unroll_.size()) {
return false;
}
if (!std::equal(
coincident_.begin(), coincident_.end(), other.coincident_.begin())) {
coincident_.begin(),
coincident_.end(),
otherBand->coincident_.begin())) {
return false;
}
if (!std::equal(unroll_.begin(), unroll_.end(), other.unroll_.begin())) {
if (!std::equal(unroll_.begin(), unroll_.end(), otherBand->unroll_.begin())) {
return false;
}

Expand All @@ -305,13 +310,13 @@ bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const {
// .domain() returns a zero-dimensional union set (in purely parameter space)
// if there is no explicit domain.
bool mupaIs0D = nMember() == 0;
bool otherMupaIs0D = other.nMember() == 0;
bool otherMupaIs0D = otherBand->nMember() == 0;
if (mupaIs0D ^ otherMupaIs0D) {
return false;
}
if (mupaIs0D && otherMupaIs0D) {
auto d1 = mupa_.domain();
auto d2 = other.mupa_.domain();
auto d2 = otherBand->mupa_.domain();
auto res = d1.is_equal(d2);
if (!res) {
LOG_IF(INFO, FLAGS_debug_tc_mapper)
Expand All @@ -322,7 +327,7 @@ bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const {
}
} else {
auto m1 = isl::union_map::from(mupa_);
auto m2 = isl::union_map::from(other.mupa_);
auto m2 = isl::union_map::from(otherBand->mupa_);
{
auto res = m1.is_equal(m2);
if (!res) {
Expand All @@ -337,74 +342,60 @@ bool ScheduleTreeBand::operator==(const ScheduleTreeBand& other) const {
return true;
}

bool ScheduleTreeContext::operator==(const ScheduleTreeContext& other) const {
auto res = context_.is_equal(other.context_);
return res;
bool ScheduleTreeContext::nodeEquals(const ScheduleTreeContext* other) const {
return other && context_.is_equal(other->context_);
}

bool ScheduleTreeDomain::operator==(const ScheduleTreeDomain& other) const {
auto res = domain_.is_equal(other.domain_);
bool ScheduleTreeDomain::nodeEquals(const ScheduleTreeDomain* other) const {
if (!other) {
return false;
}
auto res = domain_.is_equal(other->domain_);
if (!res) {
LOG_IF(INFO, FLAGS_debug_tc_mapper)
<< "ScheduleTreeDomain difference: " << domain_ << " VS "
<< other.domain_ << "\n";
<< other->domain_ << "\n";
}
return res;
}

bool ScheduleTreeExtension::operator==(
const ScheduleTreeExtension& other) const {
auto res = extension_.is_equal(other.extension_);
return res;
bool ScheduleTreeExtension::nodeEquals(
const ScheduleTreeExtension* other) const {
return other && extension_.is_equal(other->extension_);
}

bool ScheduleTreeFilter::operator==(const ScheduleTreeFilter& other) const {
auto res = filter_.is_equal(other.filter_);
return res;
bool ScheduleTreeFilter::nodeEquals(const ScheduleTreeFilter* other) const {
return other && filter_.is_equal(other->filter_);
}

bool ScheduleTreeMapping::operator==(const ScheduleTreeMapping& other) const {
auto res = filter_.is_equal(other.filter_);
return res;
bool ScheduleTreeMapping::nodeEquals(const ScheduleTreeMapping* other) const {
if (mapping.size() != other->mapping.size()) {
return false;
}
for (const auto& kvp : mapping) {
if (other->mapping.count(kvp.first) == 0) {
return false;
}
if (!other->mapping.at(kvp.first).plain_is_equal(kvp.second)) {
return false;
}
}
return filter_.is_equal(other->filter_);
}

bool ScheduleTreeSequence::operator==(const ScheduleTreeSequence& other) const {
bool ScheduleTreeSequence::nodeEquals(const ScheduleTreeSequence* other) const {
return true;
}

bool ScheduleTreeSet::operator==(const ScheduleTreeSet& other) const {
bool ScheduleTreeSet::nodeEquals(const ScheduleTreeSet* other) const {
return true;
}

bool elemEquals(
const ScheduleTree* e1,
const ScheduleTree* e2,
detail::ScheduleTreeType type) {
#define ELEM_EQUALS_CASE(CLASS) \
else if (type == CLASS::NodeType) { \
return *static_cast<const CLASS*>(e1) == *static_cast<const CLASS*>(e2); \
}

if (type == detail::ScheduleTreeType::None) {
LOG(FATAL) << "Hit Error node!";
}
ELEM_EQUALS_CASE(ScheduleTreeBand)
ELEM_EQUALS_CASE(ScheduleTreeContext)
ELEM_EQUALS_CASE(ScheduleTreeDomain)
ELEM_EQUALS_CASE(ScheduleTreeExtension)
ELEM_EQUALS_CASE(ScheduleTreeFilter)
ELEM_EQUALS_CASE(ScheduleTreeMapping)
ELEM_EQUALS_CASE(ScheduleTreeSequence)
ELEM_EQUALS_CASE(ScheduleTreeSet)
else {
LOG(FATAL) << "NYI: ScheduleTree::operator== for type: "
<< static_cast<int>(type);
}

#undef ELEM_EQUALS_CASE

return false;
bool ScheduleTreeThreadSpecificMarker::nodeEquals(
const ScheduleTreeThreadSpecificMarker* other) const {
return true;
}

} // namespace detail
} // namespace polyhedral
} // namespace tc
Loading