Skip to content

Commit 4a3c726

Browse files
authored
Fix readonly or throw to have local notions (#2416)
* Fix readonly or throw to have local notions * also loop cache to compile time error rather than assertion * fmt * fix * fmt * fmt * change to local * fix * fix
1 parent 0f88e4f commit 4a3c726

File tree

6 files changed

+191
-41
lines changed

6 files changed

+191
-41
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,7 +1961,10 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
19611961
if (CB->onlyAccessesInaccessibleMemory())
19621962
AARes = ModRefInfo::NoModRef;
19631963

1964-
bool ReadOnly = isReadOnlyOrThrow(CB);
1964+
bool ReadOnly = isLocalReadOnlyOrThrow(CB);
1965+
if (CB->hasStructRetAttr() &&
1966+
getBaseObject(CB->getArgOperand(0)) == getBaseObject(Val))
1967+
ReadOnly = false;
19651968

19661969
bool WriteOnly = isWriteOnly(CB);
19671970

@@ -3008,7 +3011,10 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR,
30083011

30093012
mayCapture |= !NoCapture;
30103013

3011-
bool ReadOnly = isReadOnlyOrThrow(call) || isReadOnly(call, idx);
3014+
bool ReadOnly = isReadOnly(call, idx);
3015+
if (!ReadOnly && isLocalReadOnlyOrThrow(call) && idx != 0 &&
3016+
call->hasStructRetAttr())
3017+
ReadOnly = true;
30123018

30133019
mayWrite |= !ReadOnly;
30143020

@@ -3392,8 +3398,9 @@ bool ActivityAnalyzer::isValueActivelyStoredOrReturned(TypeResults const &TR,
33923398

33933399
if (auto inst = dyn_cast<Instruction>(a)) {
33943400
if (!inst->mayWriteToMemory() ||
3395-
(isa<CallInst>(inst) && (AA.onlyReadsMemory(cast<CallInst>(inst)) ||
3396-
isReadOnlyOrThrow(cast<CallInst>(inst))))) {
3401+
(isa<CallInst>(inst) &&
3402+
(AA.onlyReadsMemory(cast<CallInst>(inst)) ||
3403+
isLocalReadOnlyOrThrow(cast<CallInst>(inst))))) {
33973404
// if not written to memory and returning a known constant, this
33983405
// cannot be actively returned/stored
33993406
if (inst->getParent()->getParent() == TR.getFunction() &&

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,7 +1524,7 @@ void SplitPHIs(llvm::Function &F) {
15241524
// returns if newly legal, subject to the pending calls
15251525
bool DetectReadonlyOrThrowFn(llvm::Function &F,
15261526
SmallPtrSetImpl<Function *> &calls_todo,
1527-
llvm::TargetLibraryInfo &TLI) {
1527+
llvm::TargetLibraryInfo &TLI, bool &local) {
15281528
if (isReadOnlyOrThrow(&F))
15291529
return false;
15301530
if (F.empty())
@@ -1539,8 +1539,10 @@ bool DetectReadonlyOrThrowFn(llvm::Function &F,
15391539
continue;
15401540
if (hasMetadata(&I, "enzyme_ReadOnlyOrThrow"))
15411541
continue;
1542+
if (hasMetadata(&I, "enzyme_LocalReadOnlyOrThrow"))
1543+
continue;
15421544
if (auto CI = dyn_cast<CallBase>(&I)) {
1543-
if (isReadOnlyOrThrow(CI)) {
1545+
if (isLocalReadOnlyOrThrow(CI)) {
15441546
continue;
15451547
}
15461548
if (isAllocationCall(CI, TLI)) {
@@ -1574,26 +1576,80 @@ bool DetectReadonlyOrThrowFn(llvm::Function &F,
15741576
// seen outside the function. Note, even if one stored into x =
15751577
// malloc(..), and stored x into a global/arg pointer, that second store
15761578
// would trigger not readonly.
1577-
if (isa<AllocaInst>(Obj) || isAllocationCall(Obj, TLI))
1579+
if (isa<AllocaInst>(Obj))
1580+
continue;
1581+
if (isAllocationCall(Obj, TLI)) {
1582+
if (local)
1583+
continue;
1584+
if (notCaptured(Obj))
1585+
continue;
1586+
local = true;
15781587
continue;
1588+
}
1589+
if (auto arg = dyn_cast<Argument>(Obj)) {
1590+
if (arg->hasStructRetAttr() ||
1591+
arg->getParent()
1592+
->getAttribute(arg->getArgNo() + AttributeList::FirstArgIndex,
1593+
"enzymejl_returnRoots")
1594+
.isValid()) {
1595+
local = true;
1596+
continue;
1597+
}
1598+
}
15791599
}
15801600
if (auto MTI = dyn_cast<MemTransferInst>(&I)) {
15811601
auto Obj = getBaseObject(MTI->getOperand(0));
15821602
// Storing into local memory is fine since it definitionally will not be
15831603
// seen outside the function. Note, even if one stored into x =
15841604
// malloc(..), and stored x into a global/arg pointer, that second store
15851605
// would trigger not readonly.
1586-
if (isa<AllocaInst>(Obj) || isAllocationCall(Obj, TLI))
1606+
if (isa<AllocaInst>(Obj))
1607+
continue;
1608+
if (isAllocationCall(Obj, TLI)) {
1609+
if (local)
1610+
continue;
1611+
if (notCaptured(Obj))
1612+
continue;
1613+
local = true;
15871614
continue;
1615+
}
1616+
if (auto arg = dyn_cast<Argument>(Obj)) {
1617+
if (arg->hasStructRetAttr() ||
1618+
arg->getParent()
1619+
->getAttribute(arg->getArgNo() + AttributeList::FirstArgIndex,
1620+
"enzymejl_returnRoots")
1621+
.isValid()) {
1622+
local = true;
1623+
continue;
1624+
}
1625+
}
15881626
}
15891627
if (auto MSI = dyn_cast<MemSetInst>(&I)) {
15901628
auto Obj = getBaseObject(MSI->getOperand(0));
15911629
// Storing into local memory is fine since it definitionally will not be
15921630
// seen outside the function. Note, even if one stored into x =
15931631
// malloc(..), and stored x into a global/arg pointer, that second store
15941632
// would trigger not readonly.
1595-
if (isa<AllocaInst>(Obj) || isAllocationCall(Obj, TLI))
1633+
if (isa<AllocaInst>(Obj))
1634+
continue;
1635+
if (isAllocationCall(Obj, TLI)) {
1636+
if (local)
1637+
continue;
1638+
if (notCaptured(Obj))
1639+
continue;
1640+
local = true;
15961641
continue;
1642+
}
1643+
if (auto arg = dyn_cast<Argument>(Obj)) {
1644+
if (arg->hasStructRetAttr() ||
1645+
arg->getParent()
1646+
->getAttribute(arg->getArgNo() + AttributeList::FirstArgIndex,
1647+
"enzymejl_returnRoots")
1648+
.isValid()) {
1649+
local = true;
1650+
continue;
1651+
}
1652+
}
15971653
}
15981654
// ignore atomic load impacts
15991655
if (isa<LoadInst>(&I))
@@ -1620,7 +1676,10 @@ bool DetectReadonlyOrThrowFn(llvm::Function &F,
16201676
}
16211677

16221678
if (calls_todo.size() == 0) {
1623-
F.addFnAttr("enzyme_ReadOnlyOrThrow");
1679+
if (local)
1680+
F.addFnAttr("enzyme_LocalReadOnlyOrThrow");
1681+
else
1682+
F.addFnAttr("enzyme_ReadOnlyOrThrow");
16241683
}
16251684
return true;
16261685
}
@@ -1629,6 +1688,17 @@ bool DetectReadonlyOrThrow(Module &M) {
16291688

16301689
bool changed = false;
16311690

1691+
PassBuilder PB;
1692+
LoopAnalysisManager LAM;
1693+
FunctionAnalysisManager FAM;
1694+
CGSCCAnalysisManager CGAM;
1695+
ModuleAnalysisManager MAM;
1696+
PB.registerModuleAnalyses(MAM);
1697+
PB.registerFunctionAnalyses(FAM);
1698+
PB.registerLoopAnalyses(LAM);
1699+
PB.registerCGSCCAnalyses(CGAM);
1700+
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
1701+
16321702
// Set of functions newly deduced readonlyorthrow by this pass
16331703
SmallVector<llvm::Function *> todo;
16341704

@@ -1640,21 +1710,15 @@ bool DetectReadonlyOrThrow(Module &M) {
16401710
// prerequisite for being readonly. Inverse of `todo_map`
16411711
DenseMap<llvm::Function *, SmallPtrSet<Function *, 1>> inverse_todo_map;
16421712

1643-
PassBuilder PB;
1644-
LoopAnalysisManager LAM;
1645-
FunctionAnalysisManager FAM;
1646-
CGSCCAnalysisManager CGAM;
1647-
ModuleAnalysisManager MAM;
1648-
PB.registerModuleAnalyses(MAM);
1649-
PB.registerFunctionAnalyses(FAM);
1650-
PB.registerLoopAnalyses(LAM);
1651-
PB.registerCGSCCAnalyses(CGAM);
1652-
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
1713+
SmallPtrSet<Function *, 1> LocalReadOnlyFunctions;
16531714

16541715
for (Function &F : M) {
16551716
SmallPtrSet<Function *, 1> calls_todo;
16561717
auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
1657-
if (DetectReadonlyOrThrowFn(F, calls_todo, TLI)) {
1718+
bool local = false;
1719+
if (DetectReadonlyOrThrowFn(F, calls_todo, TLI, local)) {
1720+
if (local)
1721+
LocalReadOnlyFunctions.insert(&F);
16581722
if (calls_todo.size() == 0) {
16591723
changed = true;
16601724
todo.push_back(&F);
@@ -1681,7 +1745,10 @@ bool DetectReadonlyOrThrow(Module &M) {
16811745
auto &fwd_set = found2->second;
16821746
fwd_set.erase(cur);
16831747
if (fwd_set.size() == 0) {
1684-
F2->addFnAttr("enzyme_ReadOnlyOrThrow");
1748+
if (LocalReadOnlyFunctions.contains(F2))
1749+
F2->addFnAttr("enzyme_LocalReadOnlyOrThrow");
1750+
else
1751+
F2->addFnAttr("enzyme_ReadOnlyOrThrow");
16851752
todo.push_back(F2);
16861753
todo_map.erase(F2);
16871754
}
@@ -2340,8 +2407,9 @@ Function *PreProcessCache::preprocessForClone(Function *F,
23402407

23412408
{
23422409
SmallPtrSet<Function *, 1> calls_todo;
2410+
bool local = false;
23432411
DetectReadonlyOrThrowFn(*NewF, calls_todo,
2344-
FAM.getResult<TargetLibraryAnalysis>(*NewF));
2412+
FAM.getResult<TargetLibraryAnalysis>(*NewF), local);
23452413
}
23462414

23472415
if (EnzymePrint)

enzyme/Enzyme/FunctionUtils.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,13 @@ extern llvm::cl::opt<bool> EnzymeAlwaysInlineDiff;
6464
// Perform an analysis to detect functions which only write to visible memory
6565
// outside the function if an error is not thrown. Such a function can touch
6666
// inaccessible memory [e.g. the insides of malloc/etc], and the only violation
67-
// is whether existing memory before the call is written to. In other words,
68-
// malloc, calloc, copy_array, and friends, are all considered
67+
// is whether existing memory before the call is written to.
68+
// If non-local, returning memory written to is a violation (since it writes to
69+
// externally visible memory).
70+
// If local, returning memory written to is fine (since existing memory before
71+
// the call remains unchanged).
72+
// In other words, malloc [local and non-local], calloc [local and non-local],
73+
// copy_array [local only], and friends, are all considered
6974
// readonly_or_throw, as they only either read externally visible state, throw
7075
// an error, or write to inaccessible memory.
7176
bool DetectReadonlyOrThrow(llvm::Module &M);

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2750,14 +2750,27 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
27502750
assert(innerType == Type::getInt8Ty(malloc->getContext()));
27512751
} else {
27522752
if (innerType != malloc->getType()) {
2753-
llvm::errs() << *oldFunc << "\n";
2754-
llvm::errs() << *newFunc << "\n";
2755-
llvm::errs() << "innerType: " << *innerType << "\n";
2756-
llvm::errs() << "malloc->getType(): " << *malloc->getType() << "\n";
2757-
llvm::errs() << "ret: " << *ret << " - " << *ret->getType() << "\n";
2758-
llvm::errs() << "malloc: " << *malloc << "\n";
2759-
assert(0 && "illegal loop cache type");
2760-
llvm_unreachable("illegal loop cache type");
2753+
std::string str;
2754+
raw_string_ostream ss(str);
2755+
ss << "Illegal loop cache type:\n";
2756+
ss << *oldFunc << "\n";
2757+
ss << *newFunc << "\n";
2758+
ss << "innerType: " << *innerType << "\n";
2759+
ss << "malloc->getType(): " << *malloc->getType() << "\n";
2760+
ss << "ret: " << *ret << " - " << *ret->getType() << "\n";
2761+
ss << "malloc: " << *malloc << "\n";
2762+
if (CustomErrorHandler) {
2763+
CustomErrorHandler(str.c_str(), wrap(malloc),
2764+
ErrorType::InternalError, nullptr, nullptr,
2765+
nullptr);
2766+
} else {
2767+
DebugLoc loc;
2768+
if (auto I = dyn_cast<Instruction>(malloc))
2769+
EmitFailure("LoopCache", I->getDebugLoc(), I, ss.str());
2770+
else
2771+
EmitFailure("LoopCache", DebugLoc(), newFunc, ss.str());
2772+
}
2773+
return UndefValue::get(malloc->getType());
27612774
}
27622775
}
27632776

enzyme/Enzyme/Utils.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3867,7 +3867,7 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst,
38673867
else
38683868
VI = VI->getNextNode();
38693869
SmallPtrSet<BasicBlock *, 1> regionBetween;
3870-
{
3870+
if (inst) {
38713871
SmallVector<BasicBlock *, 1> todo;
38723872
todo.push_back(VI->getParent());
38733873
while (todo.size()) {
@@ -3893,15 +3893,17 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst,
38933893
auto UI = std::get<0>(pair);
38943894
auto level = std::get<1>(pair);
38953895
auto prev = std::get<2>(pair);
3896-
if (!regionBetween.count(UI->getParent()))
3897-
continue;
3898-
if (UI->getParent() == VI->getParent()) {
3899-
if (UI->comesBefore(VI))
3896+
if (inst) {
3897+
if (!regionBetween.count(UI->getParent()))
39003898
continue;
3899+
if (UI->getParent() == VI->getParent()) {
3900+
if (UI->comesBefore(VI))
3901+
continue;
3902+
}
3903+
if (UI->getParent() == inst->getParent())
3904+
if (inst->comesBefore(UI))
3905+
continue;
39013906
}
3902-
if (UI->getParent() == inst->getParent())
3903-
if (inst->comesBefore(UI))
3904-
continue;
39053907

39063908
if (isPointerArithmeticInst(UI, /*includephi*/ true,
39073909
/*includebin*/ true)) {
@@ -3961,6 +3963,8 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst,
39613963
return true;
39623964
}
39633965

3966+
bool notCaptured(llvm::Value *V) { return notCapturedBefore(V, nullptr, 0); }
3967+
39643968
// Return true if guaranteed not to alias
39653969
// Return false if guaranteed to alias [with possible offset depending on flag].
39663970
// Return {} if no information is given.

enzyme/Enzyme/Utils.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,6 +1646,56 @@ static inline bool isReadOnly(const llvm::CallBase *call, ssize_t arg = -1) {
16461646
return false;
16471647
}
16481648

1649+
// Whether the function does not write to memory visible before the function in
1650+
// all cases that it doesn't error. In other words, the legal operations here
1651+
// are:
1652+
//. 1) Throw [in which case any operation guaranteed to throw is valid]
1653+
//. 2) Read from any memory
1654+
//. 3) Write to memory which did not exist prior to the function
1655+
// call. This means that one can write to memory whose allocation happened
1656+
// within the call to F (including a local alloca, a malloc call, even if
1657+
// returned). This is also legal to write to an sret and/or returnroots
1658+
// parameter (which must be an alloca).
1659+
static inline bool isLocalReadOnlyOrThrow(const llvm::Function *F) {
1660+
if (isReadOnly(F))
1661+
return true;
1662+
1663+
if (F->hasFnAttribute("enzyme_LocalReadOnlyOrThrow") ||
1664+
F->hasFnAttribute("enzyme_ReadOnlyOrThrow"))
1665+
return true;
1666+
1667+
return false;
1668+
}
1669+
1670+
static inline bool isLocalReadOnlyOrThrow(const llvm::CallBase *call) {
1671+
if (isReadOnly(call))
1672+
return true;
1673+
1674+
if (call->hasFnAttr("enzyme_LocalReadOnlyOrThrow") ||
1675+
call->hasFnAttr("enzyme_ReadOnlyOrThrow"))
1676+
return true;
1677+
1678+
if (auto F = getFunctionFromCall(call)) {
1679+
// Do not use function attrs if the calling conv differs, such as a julia
1680+
// call wrapping args into an array. This is because the wrapped array
1681+
// may be nocapture/readonly, but the actual arg (which will be put in the
1682+
// array) may not be.
1683+
if (F->getCallingConv() == call->getCallingConv())
1684+
if (isLocalReadOnlyOrThrow(F))
1685+
return true;
1686+
}
1687+
return false;
1688+
}
1689+
1690+
// Whether the function does not write to memory visible outside the function in
1691+
// all cases that it doesn't error. In other words, the legal operations here
1692+
// are:
1693+
//. 1) Throw [in which case any operation guaranteed to throw is valid]
1694+
//. 2) Read from any memory
1695+
//. 3) Write to memory which did not exist prior to the function
1696+
// call. This means that one can write to memory whose lifetime is
1697+
// entirely contained within F (including a local alloca, a malloc call locally
1698+
// freed, but not a returned malloc call).
16491699
static inline bool isReadOnlyOrThrow(const llvm::Function *F) {
16501700
if (isReadOnly(F))
16511701
return true;
@@ -2265,6 +2315,9 @@ bool isNVLoad(const llvm::Value *V);
22652315
bool notCapturedBefore(llvm::Value *V, llvm::Instruction *inst,
22662316
size_t checkLoadCaptured);
22672317

2318+
//! Check if value is not captured
2319+
bool notCaptured(llvm::Value *V);
2320+
22682321
// Return true if guaranteed not to alias
22692322
// Return false if guaranteed to alias [with possible offset depending on flag].
22702323
// Return {} if no information is given.

0 commit comments

Comments
 (0)