Skip to content

Commit afde9fa

Browse files
authored
Generate a deep-compare "equiv" function for IR nodes (#1294)
* add vector/namemap/nodemap support * add gtest
1 parent 05b9dbc commit afde9fa

File tree

11 files changed

+155
-68
lines changed

11 files changed

+155
-68
lines changed

ir/expression.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,11 @@ class MethodCallExpression : Expression {
337337
MethodCallExpression(Util::SourceInfo si, Expression m,
338338
const std::initializer_list<Argument> &a)
339339
: Expression(si), method(m), arguments(new Vector<Argument>(a)) {}
340+
MethodCallExpression(Expression m, const std::initializer_list<const Expression *> &a)
341+
: method(m), arguments(nullptr) {
342+
auto arguments = new Vector<Argument>;
343+
for (auto arg : a) arguments->push_back(new Argument(arg));
344+
this->arguments = arguments; }
340345
}
341346

342347
class ConstructorCallExpression : Expression {

ir/namemap.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,16 @@ class NameMap : public Node {
113113
IRNODE_SUBCLASS(NameMap)
114114
bool operator==(const Node &a) const override { return a == *this; }
115115
bool operator==(const NameMap &a) const { return symbols == a.symbols; }
116+
bool equiv(const Node &a_) const override {
117+
if (static_cast<const Node *>(this) == &a_) return true;
118+
if (typeid(*this) != typeid(a_)) return false;
119+
auto &a = static_cast<const NameMap<T, MAP, COMP, ALLOC> &>(a_);
120+
if (size() != a.size()) return false;
121+
auto it = a.begin();
122+
for (auto &el : *this)
123+
if (el.first != it->first || !el.second->equiv(*(it++)->second))
124+
return false;
125+
return true; }
116126
cstring node_type_name() const override {
117127
return "NameMap<" + T::static_type_name() + ">"; }
118128
static cstring static_type_name() {

ir/node.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,12 @@ class Node : public virtual INode {
109109
void toJSON(JSONGenerator &json) const override;
110110
void sourceInfoToJSON(JSONGenerator &json) const;
111111
Util::JsonObject* sourceInfoJsonObj() const;
112+
/* operator== does a 'shallow' comparison, comparing two Node subclass objects for equality,
113+
* and comparing pointers in the Node directly for equality */
112114
virtual bool operator==(const Node &a) const { return typeid(*this) == typeid(a); }
115+
/* 'equiv' does a deep-equals comparison, comparing all non-pointer fields and recursing
116+
* though all Node subclass pointers to compare them with 'equiv' as well. */
117+
virtual bool equiv(const Node &a) const { return typeid(*this) == typeid(a); }
113118
#define DEFINE_OPEQ_FUNC(CLASS, BASE) \
114119
virtual bool operator==(const CLASS &) const { return false; }
115120
IRNODE_ALL_SUBCLASSES(DEFINE_OPEQ_FUNC)

ir/nodemap.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2013-present Barefoot Networks, Inc.
2+
Copyright 2013-present Barefoot Networks, Inc.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -72,6 +72,16 @@ class NodeMap : public Node {
7272
IRNODE_SUBCLASS(NodeMap)
7373
bool operator==(const Node &a) const override { return a == *this; }
7474
bool operator==(const NodeMap &a) const { return symbols == a.symbols; }
75+
bool equiv(const Node &a_) const override {
76+
if (static_cast<const Node *>(this) == &a_) return true;
77+
if (typeid(*this) != typeid(a_)) return false;
78+
auto &a = static_cast<const NodeMap<KEY, VALUE, MAP, COMP, ALLOC> &>(a_);
79+
if (size() != a.size()) return false;
80+
auto it = a.begin();
81+
for (auto &el : *this)
82+
if (el.first != it->first || !el.second->equiv(*(it++)->second))
83+
return false;
84+
return true; }
7585
cstring node_type_name() const override {
7686
return "NodeMap<" + KEY::static_type_name() + "," + VALUE::static_type_name() + ">"; }
7787
static cstring static_type_name() {

ir/vector.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,14 @@ class Vector : public VectorBase {
165165
* of IR::Vector that appear somewhere in a .def file -- you can usually make
166166
* it work by using an instantiation with an (abstract) base class rather
167167
* than a concrete class, as most of those appear in .def files. */
168-
168+
bool equiv(const Node &a_) const override {
169+
if (static_cast<const Node *>(this) == &a_) return true;
170+
if (typeid(*this) != typeid(a_)) return false;
171+
auto &a = static_cast<const Vector<T> &>(a_);
172+
if (size() != a.size()) return false;
173+
auto it = a.begin();
174+
for (auto *el : *this) if (!el->equiv(**it++)) return false;
175+
return true; }
169176
cstring node_type_name() const override {
170177
return "Vector<" + T::static_type_name() + ">"; }
171178
static cstring static_type_name() {

midend/local_copyprop.cpp

Lines changed: 2 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -302,69 +302,8 @@ IR::AssignmentStatement *DoLocalCopyPropagation::preorder(IR::AssignmentStatemen
302302
return postorder(as);
303303
}
304304

305-
/* Function to check if two expressions are equivalent -
306-
* used to remove no-op assignments.
307-
* FIXME -- This function only covers a limited subset of expressions,
308-
* but will not output true for expressions which are not equivalent.
309-
* Proper deep comparisons may be required */
310-
bool DoLocalCopyPropagation::equiv(const IR::Expression *left, const IR::Expression *right) {
311-
// Compare names of variables (at this pass, all names are unique)
312-
auto pl = left->to<IR::PathExpression>();
313-
auto pr = right->to<IR::PathExpression>();
314-
if (pl && pr) {
315-
return pl->path->name == pr->path->name &&
316-
pl->path->absolute == pr->path->absolute;
317-
}
318-
auto al = left->to<IR::ArrayIndex>();
319-
auto ar = right->to<IR::ArrayIndex>();
320-
if (al && ar) {
321-
return equiv(al->left, ar->left) && equiv(al->right, ar->right);
322-
}
323-
auto tl = left->to<IR::Operation_Ternary>();
324-
auto tr = right->to<IR::Operation_Ternary>();
325-
if (tl && tr) {
326-
bool check = equiv(tl->e0, tr->e0) && equiv(tl->e1, tr->e1) && equiv(tl->e2, tr->e2) &&
327-
typeid(*tl) == typeid(*tr);
328-
return check;
329-
}
330-
331-
// Compare binary operations (array indices)
332-
auto bl = left->to<IR::Operation_Binary>();
333-
auto br = right->to<IR::Operation_Binary>();
334-
if (bl && br) {
335-
return equiv(bl->left, br->left) && equiv(bl->right, br->right) &&
336-
typeid(*bl) == typeid(*br);
337-
}
338-
// Compare packet header/metadata fields
339-
auto ml = left->to<IR::Member>();
340-
auto mr = right->to<IR::Member>();
341-
if (ml && mr) {
342-
return ml->member == mr->member && equiv(ml->expr, mr->expr);
343-
}
344-
// Compare unary operations (can be used inside array indices)
345-
auto ul = left->to<IR::Operation_Unary>();
346-
auto ur = right->to<IR::Operation_Unary>();
347-
if (ul && ur) {
348-
return equiv(ul->expr, ur->expr) &&
349-
typeid(*ul) == typeid(*ur);
350-
}
351-
352-
// Compare value and base of the constants, but do not include the type
353-
auto cl = left->to<IR::Constant>();
354-
auto cr = right->to<IR::Constant>();
355-
if (cl && cr) {
356-
return cl->base == cr->base && cl->value == cr->value &&
357-
typeid(*cl) == typeid(*cr);
358-
}
359-
360-
// Compare literals (strings, booleans and integers)
361-
if (*left == *right)
362-
return true;
363-
return false;
364-
}
365-
366305
IR::AssignmentStatement *DoLocalCopyPropagation::postorder(IR::AssignmentStatement *as) {
367-
if (equiv(as->left, as->right)) {
306+
if (as->left->equiv(*as->right)) {
368307
LOG3(" removing noop assignment " << *as);
369308
return nullptr; }
370309
// FIXME -- if as->right is an uninitialized value, we could legally eliminate this
@@ -560,7 +499,7 @@ void DoLocalCopyPropagation::apply_table(DoLocalCopyPropagation::TableInfo *tbl)
560499
forOverlapAvail(key, [key, tbl, this](VarInfo *var) {
561500
if (var->val && lvalue_out(var->val)->is<IR::PathExpression>()) {
562501
if (tbl->apply_count > 1 &&
563-
(!tbl->key_remap.count(key) || !equiv(tbl->key_remap.at(key), var->val))) {
502+
(!tbl->key_remap.count(key) || !tbl->key_remap.at(key)->equiv(*var->val))) {
564503
/* FIXME -- need deep expr comparison here, not shallow */
565504
LOG3(" different values used in different applies for key " << key);
566505
tbl->key_remap.erase(key);

midend/local_copyprop.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ class DoLocalCopyPropagation : public ControlFlowVisitor, Transform, P4WriteCont
100100
void apply_function(FuncInfo *tbl);
101101
IR::P4Table *preorder(IR::P4Table *) override;
102102
IR::P4Table *postorder(IR::P4Table *) override;
103-
bool equiv(const IR::Expression *left, const IR::Expression *right);
104103
class ElimDead;
105104
class RewriteTableKeys;
106105

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ set (GTEST_UNITTEST_SOURCES
4242
gtest/diagnostics.cpp
4343
gtest/dumpjson.cpp
4444
gtest/enumerator_test.cpp
45+
gtest/equiv_test.cpp
4546
gtest/exception_test.cpp
4647
gtest/expr_uses_test.cpp
4748
gtest/format_test.cpp

test/gtest/equiv_test.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
Copyright 2013-present Barefoot Networks, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
#include "gtest/gtest.h"
18+
#include "ir/ir.h"
19+
#include "ir/visitor.h"
20+
#include "lib/exceptions.h"
21+
22+
TEST(IR, Equiv) {
23+
auto *t = IR::Type::Bits::get(16);
24+
auto *a1 = new IR::Constant(t, 10);
25+
auto *a2 = new IR::Constant(t, 10);
26+
auto *b = new IR::Constant(IR::Type::Bits::get(10), 10);
27+
auto *c = new IR::Constant(t, 20);
28+
auto *d1 = new IR::PathExpression("d");
29+
auto *d2 = new IR::PathExpression("d");
30+
auto *e = new IR::PathExpression("e");
31+
auto *d1m = new IR::Member(d1, "m");
32+
auto *d2m = new IR::Member(d2, "m");
33+
auto *em = new IR::Member(e, "m");
34+
auto *d1f = new IR::Member(d1, "f");
35+
36+
EXPECT_TRUE(a1->equiv(*a2));
37+
EXPECT_TRUE(d1->equiv(*d2));
38+
EXPECT_FALSE(a1->equiv(*b));
39+
EXPECT_FALSE(a1->equiv(*c));
40+
EXPECT_FALSE(a1->equiv(*e));
41+
EXPECT_FALSE(d1->equiv(*e));
42+
EXPECT_TRUE(d1m->equiv(*d2m));
43+
EXPECT_FALSE(d1m->equiv(*em));
44+
EXPECT_FALSE(d1m->equiv(*d1f));
45+
46+
auto *call1 = new IR::MethodCallExpression(d1m, { a1, d1 });
47+
auto *call2 = new IR::MethodCallExpression(d2m, { a2, d2 });
48+
auto *call3 = new IR::MethodCallExpression(d1m, { b, d1 });
49+
50+
EXPECT_TRUE(call1->equiv(*call2));
51+
EXPECT_FALSE(call1->equiv(*call3));
52+
53+
auto *list1 = new IR::ListExpression({ a1, b, d1 });
54+
auto *list2 = new IR::ListExpression({ a1, b, d2 });
55+
auto *list3 = new IR::ListExpression({ a1, b, e });
56+
57+
EXPECT_TRUE(list1->equiv(*list2));
58+
EXPECT_FALSE(list1->equiv(*list3));
59+
60+
auto *pr1 = new IR::V1Program;
61+
auto *pr2 = pr1->clone();
62+
pr1->add("a", a1);
63+
pr1->add("b", b);
64+
pr1->add("call", call1);
65+
pr2->add("a", a2);
66+
pr2->add("b", b);
67+
pr2->add("call", call2);
68+
EXPECT_TRUE(pr1->equiv(*pr2));
69+
pr1->add("lista", list1);
70+
pr2->add("listb", list1);
71+
EXPECT_FALSE(pr1->equiv(*pr2));
72+
}

tools/ir-generator/irclass.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,10 @@ class IrField : public IrElement {
129129
: IrElement(info), type(type), name(name), initializer(init), nullOK(flags & NullOK),
130130
optional(flags & Optional), isInline(flags & Inline), isStatic(flags & Static),
131131
isConst(flags & Const) {}
132-
IrField(const Type *type, cstring name, cstring init = cstring())
133-
: type(type), name(name), initializer(init) {}
132+
IrField(const Type *type, cstring name, cstring init = cstring(), int flags = 0)
133+
: IrField(Util::SourceInfo(), type, name, init, flags) {}
134+
IrField(const Type *type, cstring name, int flags)
135+
: IrField(Util::SourceInfo(), type, name, cstring(), flags) {}
134136
void generate(std::ostream &out, bool asField) const;
135137
void generate_hdr(std::ostream &out) const override { generate(out, true); }
136138
void generate_impl(std::ostream &) const override;

0 commit comments

Comments
 (0)