Skip to content

Commit

Permalink
Add support for tail calls (#1090)
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniofrighetto authored Oct 4, 2024
1 parent 85b7ba5 commit 05f674c
Show file tree
Hide file tree
Showing 26 changed files with 163 additions and 63 deletions.
3 changes: 3 additions & 0 deletions ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class FnAttrs final {
std::optional<FPDenormalAttrs> fp_denormal32;
unsigned bits;
uint8_t allockind = 0;
bool is_tailcall = false;

public:
enum Attribute { None = 0, NNaN = 1 << 0, NoReturn = 1 << 1,
Expand Down Expand Up @@ -154,6 +155,8 @@ class FnAttrs final {
void add(AllocKind k) { allockind |= (uint8_t)k; }
bool has(AllocKind k) const { return allockind & (uint8_t)k; }
bool isAlloc() const { return allockind != 0 || has(AllocSize); }
void setTailCallSite(bool is_tc) { is_tailcall = is_tc; }
bool isTailCall() const { return is_tailcall; }

void inferImpliedAttributes();

Expand Down
78 changes: 62 additions & 16 deletions ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2295,7 +2295,8 @@ void FnCall::print(ostream &os) const {
if (!isVoid())
os << getName() << " = ";

os << "call " << print_type(getType())
os << (getAttributes().isTailCall() ? "tail " : "")
<< "call " << print_type(getType())
<< (fnptr ? fnptr->getName() : fnName) << '(';

bool first = true;
Expand All @@ -2313,6 +2314,31 @@ void FnCall::print(ostream &os) const {
os << ')' << attrs;
}

static void check_tailcall(const Instr &i, State &s) {
bool found = false;
const auto &instrs = s.getFn().bbOf(i).instrs();
auto it = instrs.begin();
for (auto e = instrs.end(); it != e; ++it) {
if (&*it == &i) {
found = true;
break;
}
}
assert(found);

++it;
auto &next_instr = *it;
if (auto *ret = dynamic_cast<const Return *>(&next_instr)) {
if (ret->getType().isVoid() && i.getType().isVoid())
return;
auto *ret_val = ret->operands()[0];
if (ret_val == &i)
return;
}

s.addUB(expr(false));
}

static void check_can_load(State &s, const expr &p0) {
auto &attrs = s.getFn().getFnAttrs();
if (attrs.mem.canReadAnything())
Expand Down Expand Up @@ -2493,6 +2519,9 @@ StateValue FnCall::toSMT(State &s) const {
!attrs.has(FnAttrs::WillReturn))
s.addGuardableUB(expr(false));

if (getAttributes().isTailCall())
check_tailcall(*this, s);

auto get_alloc_ptr = [&]() -> Value& {
for (auto &[arg, flags] : args) {
if (flags.has(ParamAttrs::AllocPtr))
Expand Down Expand Up @@ -4123,8 +4152,8 @@ void Memset::rauw(const Value &what, Value &with) {
}

void Memset::print(ostream &os) const {
os << "memset " << *ptr << " align " << align << ", " << *val
<< ", " << *bytes;
os << (isTailCall() ? "tail " : "") << "memset " << *ptr
<< " align " << align << ", " << *val << ", " << *bytes;
}

StateValue Memset::toSMT(State &s) const {
Expand All @@ -4144,6 +4173,9 @@ StateValue Memset::toSMT(State &s) const {
vptr = sv_ptr.value;
}
check_can_store(s, vptr);
if (isTailCall())
check_tailcall(*this, s);

s.getMemory().memset(vptr, s[*val].zextOrTrunc(8), vbytes, align,
s.getUndefVars());
return {};
Expand All @@ -4156,18 +4188,18 @@ expr Memset::getTypeConstraints(const Function &f) const {
}

unique_ptr<Instr> Memset::dup(Function &f, const string &suffix) const {
return make_unique<Memset>(*ptr, *val, *bytes, align);
return make_unique<Memset>(*ptr, *val, *bytes, align, is_tailcall);
}


DEFINE_AS_RETZEROALIGN(MemsetPattern, getMaxAllocSize);
DEFINE_AS_RETZERO(MemsetPattern, getMaxGEPOffset);

MemsetPattern::MemsetPattern(Value &ptr, Value &pattern, Value &bytes,
unsigned pattern_length)
unsigned pattern_length, bool is_tailcall)
: MemInstr(Type::voidTy, "memset_pattern" + to_string(pattern_length)),
ptr(&ptr), pattern(&pattern), bytes(&bytes),
pattern_length(pattern_length) {}
pattern_length(pattern_length), is_tailcall(is_tailcall) {}

uint64_t MemsetPattern::getMaxAccessSize() const {
return getIntOr(*bytes, UINT64_MAX);
Expand Down Expand Up @@ -4195,7 +4227,8 @@ void MemsetPattern::rauw(const Value &what, Value &with) {
}

void MemsetPattern::print(ostream &os) const {
os << getName() << ' ' << *ptr << ", " << *pattern << ", " << *bytes;
os << getName() << ' ' << (isTailCall() ? "tail " : "")
<< *ptr << ", " << *pattern << ", " << *bytes;
}

StateValue MemsetPattern::toSMT(State &s) const {
Expand All @@ -4204,6 +4237,9 @@ StateValue MemsetPattern::toSMT(State &s) const {
auto &vbytes = s.getAndAddPoisonUB(*bytes, true).value;
check_can_store(s, vptr);
check_can_load(s, vpattern);
if (isTailCall())
check_tailcall(*this, s);

s.getMemory().memset_pattern(vptr, vpattern, vbytes, pattern_length);
return {};
}
Expand All @@ -4215,7 +4251,7 @@ expr MemsetPattern::getTypeConstraints(const Function &f) const {
}

unique_ptr<Instr> MemsetPattern::dup(Function &f, const string &suffix) const {
return make_unique<MemsetPattern>(*ptr, *pattern, *bytes, pattern_length);
return make_unique<MemsetPattern>(*ptr, *pattern, *bytes, pattern_length, is_tailcall);
}


Expand Down Expand Up @@ -4297,8 +4333,9 @@ void Memcpy::rauw(const Value &what, Value &with) {
}

void Memcpy::print(ostream &os) const {
os << (move ? "memmove " : "memcpy ") << *dst << " align " << align_dst
<< ", " << *src << " align " << align_src << ", " << *bytes;
os << (isTailCall() ? "tail " : "") << (move ? "memmove " : "memcpy ")
<< *dst << " align " << align_dst << ", "
<< *src << " align " << align_src << ", " << *bytes;
}

StateValue Memcpy::toSMT(State &s) const {
Expand Down Expand Up @@ -4330,6 +4367,9 @@ StateValue Memcpy::toSMT(State &s) const {

check_can_load(s, vsrc);
check_can_store(s, vdst);
if (isTailCall())
check_tailcall(*this, s);

s.getMemory().memcpy(vdst, vsrc, vbytes, align_dst, align_src, move);
return {};
}
Expand All @@ -4341,7 +4381,7 @@ expr Memcpy::getTypeConstraints(const Function &f) const {
}

unique_ptr<Instr> Memcpy::dup(Function &f, const string &suffix) const {
return make_unique<Memcpy>(*dst, *src, *bytes, align_dst, align_src, move);
return make_unique<Memcpy>(*dst, *src, *bytes, align_dst, align_src, move, is_tailcall);
}


Expand Down Expand Up @@ -4374,8 +4414,9 @@ void Memcmp::rauw(const Value &what, Value &with) {
}

void Memcmp::print(ostream &os) const {
os << getName() << " = " << (is_bcmp ? "bcmp " : "memcmp ") << *ptr1
<< ", " << *ptr2 << ", " << *num;
os << getName() << " = " << (isTailCall() ? "tail " : "")
<< (is_bcmp ? "bcmp " : "memcmp ") << *ptr1 << ", "
<< *ptr2 << ", " << *num;
}

StateValue Memcmp::toSMT(State &s) const {
Expand All @@ -4386,6 +4427,8 @@ StateValue Memcmp::toSMT(State &s) const {

check_can_load(s, vptr1);
check_can_load(s, vptr2);
if (isTailCall())
check_tailcall(*this, s);

Pointer p1(s.getMemory(), vptr1), p2(s.getMemory(), vptr2);
// memcmp can be optimized to load & icmps, and it requires this
Expand Down Expand Up @@ -4449,7 +4492,7 @@ expr Memcmp::getTypeConstraints(const Function &f) const {

unique_ptr<Instr> Memcmp::dup(Function &f, const string &suffix) const {
return make_unique<Memcmp>(getType(), getName() + suffix, *ptr1, *ptr2, *num,
is_bcmp);
is_bcmp, is_tailcall);
}


Expand Down Expand Up @@ -4477,12 +4520,15 @@ void Strlen::rauw(const Value &what, Value &with) {
}

void Strlen::print(ostream &os) const {
os << getName() << " = strlen " << *ptr;
os << getName() << " = " << (isTailCall() ? "tail " : "")
<< "strlen " << *ptr;
}

StateValue Strlen::toSMT(State &s) const {
auto &eptr = s.getWellDefinedPtr(*ptr);
check_can_load(s, eptr);
if (isTailCall())
check_tailcall(*this, s);

Pointer p(s.getMemory(), eptr);
Type &ty = getType();
Expand All @@ -4509,7 +4555,7 @@ expr Strlen::getTypeConstraints(const Function &f) const {
}

unique_ptr<Instr> Strlen::dup(Function &f, const string &suffix) const {
return make_unique<Strlen>(getType(), getName() + suffix, *ptr);
return make_unique<Strlen>(getType(), getName() + suffix, *ptr, is_tailcall);
}


Expand Down
35 changes: 26 additions & 9 deletions ir/instr.h
Original file line number Diff line number Diff line change
Expand Up @@ -941,15 +941,18 @@ class Store final : public MemInstr {
class Memset final : public MemInstr {
Value *ptr, *val, *bytes;
uint64_t align;
bool is_tailcall;

public:
Memset(Value &ptr, Value &val, Value &bytes, uint64_t align)
Memset(Value &ptr, Value &val, Value &bytes, uint64_t align, bool is_tailcall)
: MemInstr(Type::voidTy, "memset"), ptr(&ptr), val(&val), bytes(&bytes),
align(align) {}
align(align), is_tailcall(is_tailcall) {}

Value& getPtr() const { return *ptr; }
Value& getBytes() const { return *bytes; }
uint64_t getAlign() const { return align; }
void setAlign(uint64_t align) { this->align = align; }
bool isTailCall() const { return is_tailcall; }

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
Expand All @@ -970,11 +973,14 @@ class Memset final : public MemInstr {
class MemsetPattern final : public MemInstr {
Value *ptr, *pattern, *bytes;
unsigned pattern_length;
bool is_tailcall;

public:
MemsetPattern(Value &ptr, Value &pattern, Value &bytes,
unsigned pattern_length);
unsigned pattern_length, bool is_tailcall);

unsigned getPatternLength() const { return pattern_length; }
bool isTailCall() const { return is_tailcall; }

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
Expand Down Expand Up @@ -1017,11 +1023,14 @@ class Memcpy final : public MemInstr {
Value *dst, *src, *bytes;
uint64_t align_dst, align_src;
bool move;
bool is_tailcall;

public:
Memcpy(Value &dst, Value &src, Value &bytes,
uint64_t align_dst, uint64_t align_src, bool move)
uint64_t align_dst, uint64_t align_src, bool move, bool is_tailcall)
: MemInstr(Type::voidTy, "memcpy"), dst(&dst), src(&src), bytes(&bytes),
align_dst(align_dst), align_src(align_src), move(move) {}
align_dst(align_dst), align_src(align_src), move(move),
is_tailcall(is_tailcall) {}

Value& getSrc() const { return *src; }
Value& getDst() const { return *dst; }
Expand All @@ -1031,6 +1040,7 @@ class Memcpy final : public MemInstr {
void setSrcAlign(uint64_t align) { align_src = align; }
void setDstAlign(uint64_t align) { align_dst = align; }
bool isMove() const { return move; }
bool isTailCall() const { return is_tailcall; }

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
Expand All @@ -1051,13 +1061,17 @@ class Memcpy final : public MemInstr {
class Memcmp final : public MemInstr {
Value *ptr1, *ptr2, *num;
bool is_bcmp;
bool is_tailcall;

public:
Memcmp(Type &type, std::string &&name, Value &ptr1, Value &ptr2, Value &num,
bool is_bcmp): MemInstr(type, std::move(name)), ptr1(&ptr1),
ptr2(&ptr2), num(&num), is_bcmp(is_bcmp) {}
bool is_bcmp, bool is_tailcall) : MemInstr(type, std::move(name)),
ptr1(&ptr1), ptr2(&ptr2), num(&num), is_bcmp(is_bcmp),
is_tailcall(is_tailcall) {}

Value &getBytes() const { return *num; }
bool isBCmp() const { return is_bcmp; }
bool isTailCall() const { return is_tailcall; }

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
Expand All @@ -1077,11 +1091,14 @@ class Memcmp final : public MemInstr {

class Strlen final : public MemInstr {
Value *ptr;
bool is_tailcall;

public:
Strlen(Type &type, std::string &&name, Value &ptr)
: MemInstr(type, std::move(name)), ptr(&ptr) {}
Strlen(Type &type, std::string &&name, Value &ptr, bool is_tailcall)
: MemInstr(type, std::move(name)), ptr(&ptr), is_tailcall(is_tailcall) {}

Value *getPointer() const { return ptr; }
bool isTailCall() const { return is_tailcall; }

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
Expand Down
13 changes: 7 additions & 6 deletions llvm_util/known_fns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,26 +520,27 @@ known_call(llvm::CallInst &i, const llvm::TargetLibraryInfo &TLI,
if (!decl || !TLI.getLibFunc(*decl, libfn))
RETURN_EXACT();

bool is_tailcall = i.isTailCall();
switch (libfn) {
case llvm::LibFunc_memset: // void* memset(void *ptr, int val, size_t bytes)
BB.addInstr(make_unique<Memset>(*args[0], *args[1], *args[2], 1));
BB.addInstr(make_unique<Memset>(*args[0], *args[1], *args[2], 1, is_tailcall));
RETURN_VAL(make_unique<UnaryOp>(*ty, value_name(i), *args[0],
UnaryOp::Copy));

// void memset_pattern4(void *ptr, void *pattern, size_t bytes)
case llvm::LibFunc_memset_pattern4:
RETURN_VAL(make_unique<MemsetPattern>(*args[0], *args[1], *args[2], 4));
RETURN_VAL(make_unique<MemsetPattern>(*args[0], *args[1], *args[2], 4, is_tailcall));
case llvm::LibFunc_memset_pattern8:
RETURN_VAL(make_unique<MemsetPattern>(*args[0], *args[1], *args[2], 8));
RETURN_VAL(make_unique<MemsetPattern>(*args[0], *args[1], *args[2], 8, is_tailcall));
case llvm::LibFunc_memset_pattern16:
RETURN_VAL(make_unique<MemsetPattern>(*args[0], *args[1], *args[2], 16));
RETURN_VAL(make_unique<MemsetPattern>(*args[0], *args[1], *args[2], 16, is_tailcall));
case llvm::LibFunc_strlen:
RETURN_VAL(make_unique<Strlen>(*ty, value_name(i), *args[0]));
RETURN_VAL(make_unique<Strlen>(*ty, value_name(i), *args[0], is_tailcall));
case llvm::LibFunc_memcmp:
case llvm::LibFunc_bcmp: {
RETURN_VAL(
make_unique<Memcmp>(*ty, value_name(i), *args[0], *args[1], *args[2],
libfn == llvm::LibFunc_bcmp));
libfn == llvm::LibFunc_bcmp, is_tailcall));
}
case llvm::LibFunc_ffs:
case llvm::LibFunc_ffsl:
Expand Down
6 changes: 4 additions & 2 deletions llvm_util/llvm2alive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,8 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
return error(i);

return make_unique<Memset>(*ptr, *val, *bytes,
i.getDestAlign().valueOrOne().value());
i.getDestAlign().valueOrOne().value(),
i.isTailCall());
}

RetTy visitMemTransferInst(llvm::MemTransferInst &i) {
Expand All @@ -485,7 +486,7 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
return make_unique<Memcpy>(*dst, *src, *bytes,
i.getDestAlign().valueOrOne().value(),
i.getSourceAlign().valueOrOne().value(),
isa<llvm::MemMoveInst>(&i));
isa<llvm::MemMoveInst>(&i), i.isTailCall());
}

RetTy visitICmpInst(llvm::ICmpInst &i) {
Expand Down Expand Up @@ -1756,6 +1757,7 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
handleRetAttrs(attrs_callsite.getAttributes(ret), attrs);
handleFnAttrs(attrs_callsite.getAttributes(fnidx), attrs);
attrs.mem &= handleMemAttrs(i.getMemoryEffects());
attrs.setTailCallSite(i.isTailCall());
attrs.inferImpliedAttributes();
}

Expand Down
2 changes: 1 addition & 1 deletion tests/alive-tv/asm/malloc-in-tgt.srctgt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ define i32 @src() {

define i32 @tgt() {
f:
%stack = tail call dereferenceable(64) ptr @myalloc(i32 64)
%stack = call dereferenceable(64) ptr @myalloc(i32 64)
%0 = getelementptr inbounds i8, ptr %stack, i64 48
%1 = ptrtoint ptr %0 to i64
%a0_38 = add i64 %1, -16
Expand Down
Loading

0 comments on commit 05f674c

Please sign in to comment.