Skip to content

Commit

Permalink
Refine tail call conditions for musttail
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniofrighetto committed Oct 10, 2024
1 parent 338f681 commit 0c8b3cf
Show file tree
Hide file tree
Showing 11 changed files with 225 additions and 103 deletions.
78 changes: 78 additions & 0 deletions ir/attrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// Distributed under the MIT license that can be found in the LICENSE file.

#include "ir/attrs.h"
#include "ir/function.h"
#include "ir/globals.h"
#include "ir/instr.h"
#include "ir/memory.h"
#include "ir/state.h"
#include "ir/state_value.h"
Expand Down Expand Up @@ -619,4 +621,80 @@ ostream& operator<<(std::ostream &os, FpExceptionMode ex) {
return os << str;
}

ostream& operator<<(std::ostream &os, const TailCallInfo &tci) {
const char *str = nullptr;
switch (tci.type) {
case TailCallInfo::None: str = ""; break;
case TailCallInfo::Tail: str = "tail "; break;
case TailCallInfo::MustTail: str = "musttail "; break;
}
return os << str;
}

void TailCallInfo::checkTailCall(const Instr &i, State &s) const {
bool preconditions_OK = true;
assert(type != TailCallInfo::None);

auto *callee = dynamic_cast<const FnCall *>(&i);
if (callee) {
for (const auto &[arg, attrs] : callee->getArgs()) {
bool callee_has_byval = attrs.has(ParamAttrs::ByVal);
if (dynamic_cast<const Alloc *>(arg) && !callee_has_byval) {
preconditions_OK = false;
break;
}
if (auto *input = dynamic_cast<const Input *>(arg)) {
bool caller_has_byval = input->hasAttribute(ParamAttrs::ByVal);
if (callee_has_byval != caller_has_byval) {
preconditions_OK = false;
break;
}
}
}
} else {
// Handling memcpy / memcmp et alia.
for (const auto &op : i.operands()) {
if (dynamic_cast<const Alloc *>(op)) {
preconditions_OK = false;
break;
}
}
}

if (callee && type == TailCallInfo::MustTail) {
bool callee_is_vararg = callee->getVarArgIdx() != -1u;
bool caller_is_vararg = s.getFn().isVarArgs();
if (!has_same_calling_convention || (callee_is_vararg && !caller_is_vararg))
preconditions_OK = false;
}

if (preconditions_OK && type == TailCallInfo::MustTail) {
bool found = false;
const auto &instrs = s.getFn().bbOf(i).instrs();
auto it = instrs.begin();
for (auto e = instrs.end(); it != e; ++it) {
if (&*it == &i) {
found = true;
break;
}
}
assert(found);

++it;
auto &next_instr = *it;
if (auto *ret = dynamic_cast<const Return *>(&next_instr)) {
if (ret->getType().isVoid() && i.getType().isVoid())
return;
auto *ret_val = ret->operands()[0];
if (ret_val == &i)
return;
}
}

if (!preconditions_OK) {
// Preconditions unsatifisfied or refinement for musttail failed, hence UB.
s.addUB(expr(false));
}
}

}
19 changes: 16 additions & 3 deletions ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class State;
struct StateValue;
class Type;
class Value;
class Instr;

class MemoryAccess final {
unsigned val = 0;
Expand Down Expand Up @@ -122,7 +123,6 @@ class FnAttrs final {
std::optional<FPDenormalAttrs> fp_denormal32;
unsigned bits;
uint8_t allockind = 0;
bool is_tailcall = false;

public:
enum Attribute { None = 0, NNaN = 1 << 0, NoReturn = 1 << 1,
Expand Down Expand Up @@ -155,8 +155,6 @@ class FnAttrs final {
void add(AllocKind k) { allockind |= (uint8_t)k; }
bool has(AllocKind k) const { return allockind & (uint8_t)k; }
bool isAlloc() const { return allockind != 0 || has(AllocSize); }
void setTailCallSite(bool is_tc) { is_tailcall = is_tc; }
bool isTailCall() const { return is_tailcall; }

void inferImpliedAttributes();

Expand Down Expand Up @@ -219,4 +217,19 @@ struct FpExceptionMode final {

smt::expr isfpclass(const smt::expr &v, const Type &ty, uint16_t mask);

struct TailCallInfo final {
enum TailCallType { None, Tail, MustTail } type;
// Determine if callee and caller have the same calling convention.
bool has_same_calling_convention;

TailCallInfo() : type(None), has_same_calling_convention(false) {}
TailCallInfo(TailCallType type, bool has_same_calling_convention)
: type(type), has_same_calling_convention(has_same_calling_convention) {}

TailCallType getType() const { return type; }
bool isTailCall() const { return type != None; }
void checkTailCall(const Instr &i, State &s) const;
friend std::ostream& operator<<(std::ostream &os, const TailCallInfo &tci);
};

}
95 changes: 35 additions & 60 deletions ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2295,8 +2295,7 @@ void FnCall::print(ostream &os) const {
if (!isVoid())
os << getName() << " = ";

os << (getAttributes().isTailCall() ? "tail " : "")
<< "call " << print_type(getType())
os << tci << "call " << print_type(getType())
<< (fnptr ? fnptr->getName() : fnName) << '(';

bool first = true;
Expand All @@ -2314,31 +2313,6 @@ void FnCall::print(ostream &os) const {
os << ')' << attrs;
}

static void check_tailcall(const Instr &i, State &s) {
bool found = false;
const auto &instrs = s.getFn().bbOf(i).instrs();
auto it = instrs.begin();
for (auto e = instrs.end(); it != e; ++it) {
if (&*it == &i) {
found = true;
break;
}
}
assert(found);

++it;
auto &next_instr = *it;
if (auto *ret = dynamic_cast<const Return *>(&next_instr)) {
if (ret->getType().isVoid() && i.getType().isVoid())
return;
auto *ret_val = ret->operands()[0];
if (ret_val == &i)
return;
}

s.addUB(expr(false));
}

static void check_can_load(State &s, const expr &p0) {
auto &attrs = s.getFn().getFnAttrs();
if (attrs.mem.canReadAnything())
Expand Down Expand Up @@ -2519,8 +2493,9 @@ StateValue FnCall::toSMT(State &s) const {
!attrs.has(FnAttrs::WillReturn))
s.addGuardableUB(expr(false));

if (getAttributes().isTailCall())
check_tailcall(*this, s);
const auto &attrs = getAttributes();
if (tci.isTailCall())
tci.checkTailCall(*this, s);

auto get_alloc_ptr = [&]() -> Value& {
for (auto &[arg, flags] : args) {
Expand Down Expand Up @@ -2625,6 +2600,7 @@ unique_ptr<Instr> FnCall::dup(Function &f, const string &suffix) const {
FnAttrs(attrs), fnptr, var_arg_idx);
r->args = args;
r->approx = approx;
r->tci = TailCallInfo(tci);
return r;
}

Expand Down Expand Up @@ -4152,8 +4128,8 @@ void Memset::rauw(const Value &what, Value &with) {
}

void Memset::print(ostream &os) const {
os << (isTailCall() ? "tail " : "") << "memset " << *ptr
<< " align " << align << ", " << *val << ", " << *bytes;
os << tci << "memset " << *ptr << " align " << align << ", " << *val << ", "
<< *bytes;
}

StateValue Memset::toSMT(State &s) const {
Expand All @@ -4173,8 +4149,8 @@ StateValue Memset::toSMT(State &s) const {
vptr = sv_ptr.value;
}
check_can_store(s, vptr);
if (isTailCall())
check_tailcall(*this, s);
if (tci.isTailCall())
tci.checkTailCall(*this, s);

s.getMemory().memset(vptr, s[*val].zextOrTrunc(8), vbytes, align,
s.getUndefVars());
Expand All @@ -4188,18 +4164,18 @@ expr Memset::getTypeConstraints(const Function &f) const {
}

unique_ptr<Instr> Memset::dup(Function &f, const string &suffix) const {
return make_unique<Memset>(*ptr, *val, *bytes, align, is_tailcall);
return make_unique<Memset>(*ptr, *val, *bytes, align, TailCallInfo(tci));
}


DEFINE_AS_RETZEROALIGN(MemsetPattern, getMaxAllocSize);
DEFINE_AS_RETZERO(MemsetPattern, getMaxGEPOffset);

MemsetPattern::MemsetPattern(Value &ptr, Value &pattern, Value &bytes,
unsigned pattern_length, bool is_tailcall)
: MemInstr(Type::voidTy, "memset_pattern" + to_string(pattern_length)),
ptr(&ptr), pattern(&pattern), bytes(&bytes),
pattern_length(pattern_length), is_tailcall(is_tailcall) {}
unsigned pattern_length, TailCallInfo tci)
: MemInstr(Type::voidTy, "memset_pattern" + to_string(pattern_length)),
ptr(&ptr), pattern(&pattern), bytes(&bytes),
pattern_length(pattern_length), tci(tci) {}

uint64_t MemsetPattern::getMaxAccessSize() const {
return getIntOr(*bytes, UINT64_MAX);
Expand Down Expand Up @@ -4227,8 +4203,7 @@ void MemsetPattern::rauw(const Value &what, Value &with) {
}

void MemsetPattern::print(ostream &os) const {
os << getName() << ' ' << (isTailCall() ? "tail " : "")
<< *ptr << ", " << *pattern << ", " << *bytes;
os << getName() << ' ' << tci << *ptr << ", " << *pattern << ", " << *bytes;
}

StateValue MemsetPattern::toSMT(State &s) const {
Expand All @@ -4237,8 +4212,8 @@ StateValue MemsetPattern::toSMT(State &s) const {
auto &vbytes = s.getAndAddPoisonUB(*bytes, true).value;
check_can_store(s, vptr);
check_can_load(s, vpattern);
if (isTailCall())
check_tailcall(*this, s);
if (tci.isTailCall())
tci.checkTailCall(*this, s);

s.getMemory().memset_pattern(vptr, vpattern, vbytes, pattern_length);
return {};
Expand All @@ -4251,7 +4226,8 @@ expr MemsetPattern::getTypeConstraints(const Function &f) const {
}

unique_ptr<Instr> MemsetPattern::dup(Function &f, const string &suffix) const {
return make_unique<MemsetPattern>(*ptr, *pattern, *bytes, pattern_length, is_tailcall);
return make_unique<MemsetPattern>(*ptr, *pattern, *bytes, pattern_length,
TailCallInfo(tci));
}


Expand Down Expand Up @@ -4333,9 +4309,8 @@ void Memcpy::rauw(const Value &what, Value &with) {
}

void Memcpy::print(ostream &os) const {
os << (isTailCall() ? "tail " : "") << (move ? "memmove " : "memcpy ")
<< *dst << " align " << align_dst << ", "
<< *src << " align " << align_src << ", " << *bytes;
os << tci << (move ? "memmove " : "memcpy ") << *dst << " align " << align_dst
<< ", " << *src << " align " << align_src << ", " << *bytes;
}

StateValue Memcpy::toSMT(State &s) const {
Expand Down Expand Up @@ -4367,8 +4342,8 @@ StateValue Memcpy::toSMT(State &s) const {

check_can_load(s, vsrc);
check_can_store(s, vdst);
if (isTailCall())
check_tailcall(*this, s);
if (tci.isTailCall())
tci.checkTailCall(*this, s);

s.getMemory().memcpy(vdst, vsrc, vbytes, align_dst, align_src, move);
return {};
Expand All @@ -4381,7 +4356,8 @@ expr Memcpy::getTypeConstraints(const Function &f) const {
}

unique_ptr<Instr> Memcpy::dup(Function &f, const string &suffix) const {
return make_unique<Memcpy>(*dst, *src, *bytes, align_dst, align_src, move, is_tailcall);
return make_unique<Memcpy>(*dst, *src, *bytes, align_dst, align_src, move,
TailCallInfo(tci));
}


Expand Down Expand Up @@ -4414,9 +4390,8 @@ void Memcmp::rauw(const Value &what, Value &with) {
}

void Memcmp::print(ostream &os) const {
os << getName() << " = " << (isTailCall() ? "tail " : "")
<< (is_bcmp ? "bcmp " : "memcmp ") << *ptr1 << ", "
<< *ptr2 << ", " << *num;
os << getName() << " = " << tci << (is_bcmp ? "bcmp " : "memcmp ") << *ptr1
<< ", " << *ptr2 << ", " << *num;
}

StateValue Memcmp::toSMT(State &s) const {
Expand All @@ -4427,8 +4402,8 @@ StateValue Memcmp::toSMT(State &s) const {

check_can_load(s, vptr1);
check_can_load(s, vptr2);
if (isTailCall())
check_tailcall(*this, s);
if (tci.isTailCall())
tci.checkTailCall(*this, s);

Pointer p1(s.getMemory(), vptr1), p2(s.getMemory(), vptr2);
// memcmp can be optimized to load & icmps, and it requires this
Expand Down Expand Up @@ -4492,7 +4467,7 @@ expr Memcmp::getTypeConstraints(const Function &f) const {

unique_ptr<Instr> Memcmp::dup(Function &f, const string &suffix) const {
return make_unique<Memcmp>(getType(), getName() + suffix, *ptr1, *ptr2, *num,
is_bcmp, is_tailcall);
is_bcmp, TailCallInfo(tci));
}


Expand Down Expand Up @@ -4520,15 +4495,14 @@ void Strlen::rauw(const Value &what, Value &with) {
}

void Strlen::print(ostream &os) const {
os << getName() << " = " << (isTailCall() ? "tail " : "")
<< "strlen " << *ptr;
os << getName() << " = " << tci << "strlen " << *ptr;
}

StateValue Strlen::toSMT(State &s) const {
auto &eptr = s.getWellDefinedPtr(*ptr);
check_can_load(s, eptr);
if (isTailCall())
check_tailcall(*this, s);
if (tci.isTailCall())
tci.checkTailCall(*this, s);

Pointer p(s.getMemory(), eptr);
Type &ty = getType();
Expand All @@ -4555,7 +4529,8 @@ expr Strlen::getTypeConstraints(const Function &f) const {
}

unique_ptr<Instr> Strlen::dup(Function &f, const string &suffix) const {
return make_unique<Strlen>(getType(), getName() + suffix, *ptr, is_tailcall);
return make_unique<Strlen>(getType(), getName() + suffix, *ptr,
TailCallInfo(tci));
}


Expand Down
Loading

0 comments on commit 0c8b3cf

Please sign in to comment.