Skip to content

Commit 7f8f5f7

Browse files
committed
Use LHS components as synthesis components
1 parent 2844815 commit 7f8f5f7

File tree

4 files changed

+130
-46
lines changed

4 files changed

+130
-46
lines changed

include/souper/Infer/InstSynthesis.h

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,13 @@ typedef std::pair<unsigned, unsigned> LocVar;
7070
typedef std::pair<LocVar, Inst *> LocInst;
7171

7272
/// A component is a fixed-width instruction kind
73+
/// or created from Origin
7374
struct Component {
7475
Inst::Kind Kind;
7576
unsigned Width;
7677
std::vector<unsigned> OpWidths;
78+
Inst *Origin;
79+
std::vector<Inst *> OriginOps;
7780
};
7881

7982
/// Unsupported components kinds
@@ -94,35 +97,35 @@ static const std::set<Inst::Kind> UnsupportedCompKinds = {
9497
/// a component of that width is instantiated.
9598
/// Again, note that constants are treated as ordinary inputs
9699
static const std::vector<Component> CompLibrary = {
97-
Component{Inst::Add, 0, {0,0}},
98-
Component{Inst::Sub, 0, {0,0}},
99-
Component{Inst::Mul, 0, {0,0}},
100-
Component{Inst::UDiv, 0, {0,0}},
101-
Component{Inst::SDiv, 0, {0,0}},
102-
Component{Inst::UDivExact, 0, {0,0}},
103-
Component{Inst::SDivExact, 0, {0,0}},
104-
Component{Inst::URem, 0, {0,0}},
105-
Component{Inst::SRem, 0, {0,0}},
106-
Component{Inst::And, 0, {0,0}},
107-
Component{Inst::Or, 0, {0,0}},
108-
Component{Inst::Xor, 0, {0,0}},
109-
Component{Inst::Shl, 0, {0,0}},
110-
Component{Inst::LShr, 0, {0,0}},
111-
Component{Inst::LShrExact, 0, {0,0}},
112-
Component{Inst::AShr, 0, {0,0}},
113-
Component{Inst::AShrExact, 0, {0,0}},
114-
Component{Inst::Select, 0, {1,0,0}},
115-
Component{Inst::Eq, 1, {0,0}},
116-
Component{Inst::Ne, 1, {0,0}},
117-
Component{Inst::Ult, 1, {0,0}},
118-
Component{Inst::Slt, 1, {0,0}},
119-
Component{Inst::Ule, 1, {0,0}},
120-
Component{Inst::Sle, 1, {0,0}},
100+
Component{Inst::Add, 0, {0,0}, 0, {}},
101+
Component{Inst::Sub, 0, {0,0}, 0, {}},
102+
Component{Inst::Mul, 0, {0,0}, 0, {}},
103+
Component{Inst::UDiv, 0, {0,0}, 0, {}},
104+
Component{Inst::SDiv, 0, {0,0}, 0, {}},
105+
Component{Inst::UDivExact, 0, {0,0}, 0, {}},
106+
Component{Inst::SDivExact, 0, {0,0}, 0, {}},
107+
Component{Inst::URem, 0, {0,0}, 0, {}},
108+
Component{Inst::SRem, 0, {0,0}, 0, {}},
109+
Component{Inst::And, 0, {0,0}, 0, {}},
110+
Component{Inst::Or, 0, {0,0}, 0, {}},
111+
Component{Inst::Xor, 0, {0,0}, 0, {}},
112+
Component{Inst::Shl, 0, {0,0}, 0, {}},
113+
Component{Inst::LShr, 0, {0,0}, 0, {}},
114+
Component{Inst::LShrExact, 0, {0,0}, 0, {}},
115+
Component{Inst::AShr, 0, {0,0}, 0, {}},
116+
Component{Inst::AShrExact, 0, {0,0}, 0, {}},
117+
Component{Inst::Select, 0, {1,0,0}, 0, {}},
118+
Component{Inst::Eq, 1, {0,0}, 0, {}},
119+
Component{Inst::Ne, 1, {0,0}, 0, {}},
120+
Component{Inst::Ult, 1, {0,0}, 0, {}},
121+
Component{Inst::Slt, 1, {0,0}, 0, {}},
122+
Component{Inst::Ule, 1, {0,0}, 0, {}},
123+
Component{Inst::Sle, 1, {0,0}, 0, {}},
121124
//
122-
Component{Inst::CtPop, 0, {0}},
123-
Component{Inst::BSwap, 0, {0}},
124-
Component{Inst::Cttz, 0, {0}},
125-
Component{Inst::Ctlz, 0, {0}}
125+
Component{Inst::CtPop, 0, {0}, 0, {}},
126+
Component{Inst::BSwap, 0, {0}, 0, {}},
127+
Component{Inst::Cttz, 0, {0}, 0, {}},
128+
Component{Inst::Ctlz, 0, {0}, 0, {}}
126129
};
127130

128131
class InstSynthesis {
@@ -132,13 +135,15 @@ class InstSynthesis {
132135
const BlockPCs &BPCs,
133136
const std::vector<InstMapping> &PCs,
134137
Inst *TargetLHS, Inst *&RHS,
138+
const std::vector<Inst *> &LHSComps,
135139
InstContext &IC, unsigned Timeout);
136140

137141
private:
138142
/// Local references
139143
SMTLIBSolver *LSMTSolver;
140144
const BlockPCs *LBPCs;
141145
const std::vector<InstMapping> *LPCs;
146+
const std::vector<Inst *> *LLHSComps;
142147
InstContext *LIC;
143148
unsigned LTimeout;
144149

@@ -291,6 +296,7 @@ class InstSynthesis {
291296

292297
/// Helper functions
293298
void filterFixedWidthIntrinsicComps();
299+
Component getCompFromInst(Inst *);
294300
void getInputVars(Inst *I, std::vector<Inst *> &InputVars);
295301
std::string getLocVarStr(const LocVar &Loc, const std::string Prefix="");
296302
LocVar getLocVarFromStr(const std::string &Str);

lib/Extractor/Solver.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,11 @@ class BaseSolver : public Solver {
206206
}
207207

208208
if (InferInsts && SMTSolver->supportsModels()) {
209+
std::vector<Inst *> LHSComps;
210+
findCands(LHS, LHSComps, IC, MaxNops);
209211
InstSynthesis IS;
210-
EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS, IC, Timeout);
212+
EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS,
213+
LHSComps, IC, Timeout);
211214
if (EC || RHS)
212215
return EC;
213216
}

lib/Infer/InstSynthesis.cpp

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,15 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver,
5959
const BlockPCs &BPCs,
6060
const std::vector<InstMapping> &PCs,
6161
Inst *TargetLHS, Inst *&RHS,
62+
const std::vector<Inst *> &LHSComps,
6263
InstContext &IC, unsigned Timeout) {
6364
std::error_code EC;
6465

6566
// init local refs
6667
LSMTSolver = SMTSolver;
6768
LBPCs = &BPCs;
6869
LPCs = &PCs;
70+
LLHSComps = &LHSComps;
6971
LIC = &IC;
7072
LTimeout = Timeout;
7173

@@ -91,7 +93,7 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver,
9193

9294
if (DebugLevel > 0) {
9395
llvm::outs() << "; starting synthesis for LHS\n";
94-
PrintReplacementLHS(llvm::outs(), BPCs, PCs, LHS, Context);
96+
PrintReplacementLHS(llvm::outs(), BPCs, PCs, LHS, Context, true);
9597
if (DebugLevel > 2)
9698
printInitInfo();
9799
}
@@ -322,7 +324,7 @@ void InstSynthesis::setCompLibrary() {
322324
for (auto KindStr : splitString(CmdUserCompKinds.c_str())) {
323325
Inst::Kind K = Inst::getKind(KindStr);
324326
if (KindStr == Inst::getKindName(Inst::Const)) // Special case
325-
InitConstComps.push_back(Component{Inst::Const, 0, {}});
327+
InitConstComps.push_back(Component{Inst::Const, 0, {}, 0, {}});
326328
else if (K == Inst::ZExt || K == Inst::SExt || K == Inst::Trunc)
327329
report_fatal_error("don't use zext/sext/trunc explicitly");
328330
else if (K == Inst::None)
@@ -338,13 +340,13 @@ void InstSynthesis::setCompLibrary() {
338340
InitComps.push_back(Comp);
339341
} else {
340342
InitComps = CompLibrary;
341-
InitConstComps.push_back(Component{Inst::Const, 0, {}});
343+
InitConstComps.push_back(Component{Inst::Const, 0, {}, 0, {}});
342344
}
343345
for (auto const &In : Inputs) {
344346
if (In->Width == DefaultWidth)
345347
continue;
346-
Comps.push_back(Component{Inst::ZExt, DefaultWidth, {In->Width}});
347-
Comps.push_back(Component{Inst::SExt, DefaultWidth, {In->Width}});
348+
Comps.push_back(Component{Inst::ZExt, DefaultWidth, {In->Width}, 0, {}});
349+
Comps.push_back(Component{Inst::SExt, DefaultWidth, {In->Width}, 0, {}});
348350
}
349351
// Second, for each input/constant create a component of DefaultWidth
350352
for (auto &Comp : InitComps) {
@@ -362,7 +364,23 @@ void InstSynthesis::setCompLibrary() {
362364
}
363365
// Third, create one trunc comp to match the output width if necessary
364366
if (LHS->Width < DefaultWidth)
365-
Comps.push_back(Component{Inst::Trunc, LHS->Width, {DefaultWidth}});
367+
Comps.push_back(Component{Inst::Trunc, LHS->Width, {DefaultWidth}, 0, {}});
368+
// Finally, add LHS components (if provided) directly to Comps,
369+
// their widths are already initialized.
370+
for (auto I : *LLHSComps) {
371+
// No support for the following Insts
372+
switch (I->K) {
373+
case Inst::Phi:
374+
// TODO: Why do we get these as candidates?!
375+
case Inst::Var:
376+
case Inst::Const:
377+
case Inst::UntypedConst:
378+
continue;
379+
default:
380+
break;
381+
}
382+
Comps.push_back(getCompFromInst(I));
383+
}
366384
}
367385

368386
void InstSynthesis::initInputVars(InstContext &IC) {
@@ -438,10 +456,11 @@ void InstSynthesis::filterFixedWidthIntrinsicComps() {
438456

439457
void InstSynthesis::initComponents(InstContext &IC) {
440458
for (unsigned J = 0; J < Comps.size(); ++J) {
441-
auto const &Comp = Comps[J];
459+
auto &Comp = Comps[J];
442460
std::string LocVarStr;
443461
// First, init component inputs
444462
std::vector<Inst *> CompOps;
463+
std::map<Inst *, Inst *> OpsReplacements;
445464
std::vector<LocVar> OpsLocVar;
446465
for (unsigned K = 0; K < Comp.OpWidths.size(); ++K) {
447466
LocVar In = std::make_pair(J+1, K+1);
@@ -464,6 +483,11 @@ void InstSynthesis::initComponents(InstContext &IC) {
464483
CompInstMap[In] = OpInst;
465484
CompOps.push_back(OpInst);
466485
OpsLocVar.push_back(In);
486+
// Update OpsReplacements
487+
if (Comp.Origin) {
488+
assert(Comp.OriginOps.size());
489+
OpsReplacements.insert(std::make_pair(Comp.OriginOps[K], OpInst));
490+
}
467491
}
468492
// Store all input locations
469493
CompOpLocVars.push_back(OpsLocVar);
@@ -479,13 +503,23 @@ void InstSynthesis::initComponents(InstContext &IC) {
479503
// Third, instantiate the component (aka Inst)
480504
assert(Comp.Width && "comp width not set");
481505
Inst *CompInst;
482-
if (Comp.Kind == Inst::Select) {
483-
Inst *C = IC.getInst(Inst::Trunc, 1, {CompOps[0]});
484-
CompInst = IC.getInst(Comp.Kind, Comp.Width, {C, CompOps[1], CompOps[2]});
485-
} else {
486-
CompInst = IC.getInst(Comp.Kind, Comp.Width, CompOps);
506+
if (Comp.Origin) {
507+
assert(Comp.OriginOps.size() == CompOps.size());
508+
CompInst = getInstCopy(Comp.Origin, *LIC, OpsReplacements);
487509
if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc)
488510
CompInst = IC.getInst(Inst::ZExt, DefaultWidth, {CompInst});
511+
// Update LHS component
512+
Comp.Origin = CompInst;
513+
Comp.OriginOps = CompOps;
514+
} else {
515+
if (Comp.Kind == Inst::Select) {
516+
Inst *C = IC.getInst(Inst::Trunc, 1, {CompOps[0]});
517+
CompInst = IC.getInst(Comp.Kind, Comp.Width, {C, CompOps[1], CompOps[2]});
518+
} else {
519+
CompInst = IC.getInst(Comp.Kind, Comp.Width, CompOps);
520+
if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc)
521+
CompInst = IC.getInst(Inst::ZExt, DefaultWidth, {CompInst});
522+
}
489523
}
490524
// Update CompInstMap map with concrete Inst
491525
CompInstMap[Out] = CompInst;
@@ -517,12 +551,14 @@ void InstSynthesis::printInitInfo() {
517551
llvm::outs() << "N: " << N << ", M: " << M << "\n";
518552
llvm::outs() << "default width: " << DefaultWidth << "\n";
519553
llvm::outs() << "output width: " << LHS->Width << "\n";
520-
llvm::outs() << "component library: ";
554+
llvm::outs() << "component library: " << Comps.size() << "\n";
521555
for (auto const &Comp : Comps) {
522556
llvm::outs() << Inst::getKindName(Comp.Kind) << " (" << Comp.Width << ", { ";
523557
for (auto const &Width : Comp.OpWidths)
524558
llvm::outs() << Width << " ";
525-
llvm::outs() << "}); ";
559+
llvm::outs() << "})\n";
560+
if (Comp.Origin)
561+
PrintReplacementRHS(llvm::outs(), Comp.Origin, Context, true);
526562
}
527563
if (Comps.size())
528564
llvm::outs() << "\n";
@@ -980,15 +1016,28 @@ Inst *InstSynthesis::createInstFromWiring(
9801016
llvm::outs() << "- creating inst " << Inst::getKindName(Comp.Kind)
9811017
<< ", width " << Comp.Width << "\n";
9821018
llvm::outs() << "before junk removal:\n";
983-
PrintReplacementRHS(llvm::outs(), IC.getInst(Comp.Kind, Comp.Width, Ops),
984-
Context);
1019+
if (Comp.Origin)
1020+
PrintReplacementRHS(llvm::outs(), Comp.Origin, Context);
1021+
else
1022+
PrintReplacementRHS(llvm::outs(), IC.getInst(Comp.Kind, Comp.Width, Ops),
1023+
Context);
9851024
}
9861025
// Sanity checks
9871026
if (Ops.size() == 2 && Ops[0]->K == Inst::Const && Ops[1]->K == Inst::Const)
9881027
report_fatal_error("inst operands are constants!");
9891028
assert(Comp.Width == 1 || Comp.Width == DefaultWidth ||
9901029
Comp.Width == LHS->Width);
991-
// Create instruction
1030+
// Instruction is a LHS component
1031+
if (Comp.Origin) {
1032+
assert(Comp.OriginOps.size() == Ops.size());
1033+
std::map<Inst *, Inst *> OpsReplacements;
1034+
for (unsigned J = 0; J < Ops.size(); ++J)
1035+
OpsReplacements.insert(std::make_pair(Comp.OriginOps[J], Ops[J]));
1036+
Inst *Copy = getInstCopy(Comp.Origin, *LIC, OpsReplacements);
1037+
// Update ops
1038+
Ops = Copy->Ops;
1039+
}
1040+
// Create instruction from a component
9921041
if (Comp.Kind == Inst::Select) {
9931042
Ops[0] = IC.getInst(Inst::Trunc, 1, {Ops[0]});
9941043
return createCleanInst(Comp.Kind, Comp.Width, Ops, IC);
@@ -1214,6 +1263,18 @@ Inst *InstSynthesis::createCleanInst(Inst::Kind Kind, unsigned Width,
12141263
return IC.getInst(Kind, Width, Ops);
12151264
}
12161265

1266+
Component InstSynthesis::getCompFromInst(Inst *I) {
1267+
std::vector<Inst *> IV;
1268+
getInputVars(I, IV);
1269+
sort(IV.begin(), IV.end());
1270+
IV.erase(unique(IV.begin(), IV.end()), IV.end());
1271+
std::vector<unsigned> OpWidths;
1272+
for (auto In : IV)
1273+
OpWidths.push_back(In->Width);
1274+
1275+
return Component{I->K, I->Width, OpWidths, I, IV};
1276+
}
1277+
12171278
void InstSynthesis::getInputVars(Inst *I, std::vector<Inst *> &InputVars) {
12181279
if (I->K == Inst::Var)
12191280
InputVars.push_back(I);

test/Infer/four-adds.opt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
; REQUIRES: solver, solver-model
2+
3+
; -souper-synthesis-comps=const is just a hack to avoid the initialization of the whole component library
4+
; RUN: %souper-check %solver -infer-rhs -souper-infer-inst -souper-synthesis-comps=const -souper-synthesis-ignore-cost %s > %t1
5+
; RUN: %FileCheck %s < %t1
6+
7+
; CHECK: result %4
8+
9+
%0:i32 = var
10+
%1:i32 = add 1:i32, %0
11+
%2:i32 = add 1:i32, %1
12+
%3:i32 = add 1:i32, %2
13+
%4:i32 = add 1:i32, %3
14+
infer %4

0 commit comments

Comments
 (0)