Skip to content

Commit fe2b323

Browse files
authored
Rewire return type generation to explicitly mark primal/shadow (#2425)
* Mark readonly or throw as no-propagate
* Rewire return type generation to explicitly mark primal/shadow
* fix
* fix
* fmt
* fix
1 parent 3a44f70 commit fe2b323

File tree

7 files changed

+65
-140
lines changed

7 files changed

+65
-140
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5005,14 +5005,10 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
50055005
auto ft = call.getFunctionType();
50065006
bool retActive = subretType != DIFFE_TYPE::CONSTANT;
50075007

5008-
ReturnType subretVal =
5009-
subretused
5010-
? (retActive ? ReturnType::TwoReturns : ReturnType::Return)
5011-
: (retActive ? ReturnType::Return : ReturnType::Void);
5012-
50135008
FT = getFunctionTypeForClone(
50145009
ft, Mode, gutils->getWidth(), tape ? tape->getType() : nullptr,
5015-
argsInverted, false, subretVal, subretType);
5010+
argsInverted, false, /*returnTape*/ false,
5011+
/*returnPrimal*/ subretused, /*returnShadow*/ retActive);
50165012
PointerType *fptype = PointerType::getUnqual(FT);
50175013
newcalled = BuilderZ.CreatePointerCast(newcalled,
50185014
PointerType::getUnqual(fptype));

enzyme/Enzyme/DiffeGradientUtils.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
8989
bool strongZero, unsigned width, Function *todiff, TargetLibraryInfo &TLI,
9090
TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType,
9191
bool shadowReturn, bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
92-
ReturnType returnValue, Type *additionalArg, bool omp) {
92+
bool returnTape, bool returnPrimal, Type *additionalArg, bool omp) {
9393
Function *oldFunc = todiff;
9494
assert(mode == DerivativeMode::ReverseModeGradient ||
9595
mode == DerivativeMode::ReverseModeCombined ||
@@ -126,7 +126,8 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
126126

127127
auto newFunc = Logic.PPC.CloneFunctionWithReturns(
128128
mode, width, oldFunc, invertedPointers, constant_args, constant_values,
129-
nonconstant_values, returnvals, returnValue, retType,
129+
nonconstant_values, returnvals, returnTape, returnPrimal,
130+
(mode == DerivativeMode::ReverseModeGradient) ? false : shadowReturn,
130131
prefix + oldFunc->getName(), &originalToNew,
131132
/*diffeReturnArg*/ diffeReturnArg, additionalArg);
132133

enzyme/Enzyme/DiffeGradientUtils.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,8 @@ class DiffeGradientUtils final : public GradientUtils {
8383
llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA,
8484
FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType,
8585
bool shadowReturnArg, bool diffeReturnArg,
86-
llvm::ArrayRef<DIFFE_TYPE> constant_args,
87-
ReturnType returnValue, llvm::Type *additionalArg, bool omp);
86+
llvm::ArrayRef<DIFFE_TYPE> constant_args, bool returnTape,
87+
bool returnPrimal, llvm::Type *additionalArg, bool omp);
8888

8989
llvm::AllocaInst *getDifferential(llvm::Value *val);
9090

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 28 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -3166,7 +3166,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
31663166
}
31673167

31683168
void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
3169-
DIFFE_TYPE retType, ReturnType retVal) {
3169+
DIFFE_TYPE retType, bool returnPrimal,
3170+
bool returnShadow) {
31703171
TypeResults &TR = gutils->TR;
31713172
ReturnInst *inst = dyn_cast<ReturnInst>(oBB->getTerminator());
31723173
// In forward mode we only need to update the return value
@@ -3212,75 +3213,45 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
32123213
}
32133214
}
32143215

3215-
switch (retVal) {
3216-
case ReturnType::Return: {
3217-
auto ret = inst->getOperand(0);
3218-
3219-
Type *rt = ret->getType();
3220-
while (auto AT = dyn_cast<ArrayType>(rt))
3221-
rt = AT->getElementType();
3222-
bool floatLike = rt->isFPOrFPVectorTy();
3223-
3224-
if (retType == DIFFE_TYPE::CONSTANT) {
3225-
toret = gutils->getNewFromOriginal(ret);
3226-
} else if (!floatLike &&
3227-
TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
3228-
toret = invertedPtr ? invertedPtr : gutils->invertPointerM(ret, nBuilder);
3229-
} else if (!gutils->isConstantValue(ret)) {
3230-
assert(!invertedPtr);
3231-
toret = gutils->diffe(ret, nBuilder);
3232-
} else {
3233-
toret = invertedPtr
3234-
? invertedPtr
3235-
: gutils->invertPointerM(ret, nBuilder, /*nullInit*/ true);
3236-
}
3216+
Value *primal = nullptr;
3217+
Value *shadow = nullptr;
32373218

3238-
break;
3219+
if (returnPrimal) {
3220+
auto ret = inst->getOperand(0);
3221+
primal = gutils->getNewFromOriginal(ret);
32393222
}
3240-
case ReturnType::TwoReturns: {
3241-
if (retType == DIFFE_TYPE::CONSTANT)
3242-
assert(false && "Invalid return type");
3223+
if (returnShadow) {
32433224
auto ret = inst->getOperand(0);
3244-
32453225
Type *rt = ret->getType();
32463226
while (auto AT = dyn_cast<ArrayType>(rt))
32473227
rt = AT->getElementType();
32483228
bool floatLike = rt->isFPOrFPVectorTy();
32493229

3250-
toret =
3251-
nBuilder.CreateInsertValue(toret, gutils->getNewFromOriginal(ret), 0);
3252-
32533230
if (!floatLike && TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
3254-
toret = nBuilder.CreateInsertValue(
3255-
toret,
3256-
invertedPtr ? invertedPtr : gutils->invertPointerM(ret, nBuilder), 1);
3231+
shadow =
3232+
invertedPtr ? invertedPtr : gutils->invertPointerM(ret, nBuilder);
32573233
} else if (!gutils->isConstantValue(ret)) {
32583234
assert(!invertedPtr);
3259-
toret =
3260-
nBuilder.CreateInsertValue(toret, gutils->diffe(ret, nBuilder), 1);
3235+
shadow = gutils->diffe(ret, nBuilder);
32613236
} else {
3262-
toret = nBuilder.CreateInsertValue(
3263-
toret,
3264-
invertedPtr
3265-
? invertedPtr
3266-
: gutils->invertPointerM(ret, nBuilder, /*nullInit*/ true),
3267-
1);
3237+
shadow = invertedPtr
3238+
? invertedPtr
3239+
: gutils->invertPointerM(ret, nBuilder, /*nullInit*/ true);
32683240
}
3269-
break;
32703241
}
3271-
case ReturnType::Void: {
3242+
3243+
if (primal && shadow) {
3244+
toret = nBuilder.CreateInsertValue(toret, primal, 0);
3245+
toret = nBuilder.CreateInsertValue(toret, shadow, 1);
3246+
} else if (primal) {
3247+
toret = primal;
3248+
} else if (shadow) {
3249+
toret = shadow;
3250+
} else {
32723251
gutils->erase(gutils->getNewFromOriginal(inst));
32733252
nBuilder.CreateRetVoid();
32743253
return;
32753254
}
3276-
default: {
3277-
llvm::errs() << "Invalid return type: " << to_string(retVal)
3278-
<< "for function: \n"
3279-
<< gutils->newFunc << "\n";
3280-
assert(false && "Invalid return type for function");
3281-
return;
3282-
}
3283-
}
32843255

32853256
gutils->erase(newInst);
32863257
nBuilder.CreateRet(toret);
@@ -4234,19 +4205,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
42344205
assert(augmenteddata->constant_args == key.constant_args);
42354206
}
42364207

4237-
ReturnType retVal =
4238-
key.returnUsed ? (key.shadowReturnUsed ? ReturnType::ArgsWithTwoReturns
4239-
: ReturnType::ArgsWithReturn)
4240-
: (key.shadowReturnUsed ? ReturnType::ArgsWithReturn
4241-
: ReturnType::Args);
4242-
42434208
bool diffeReturnArg = key.retType == DIFFE_TYPE::OUT_DIFF;
42444209

42454210
DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(
42464211
*this, key.mode, key.runtimeActivity, key.strongZero, key.width,
42474212
key.todiff, TLI, TA, oldTypeInfo, key.retType,
42484213
augmenteddata ? augmenteddata->shadowReturnUsed : key.shadowReturnUsed,
4249-
diffeReturnArg, key.constant_args, retVal, key.additionalType, omp);
4214+
diffeReturnArg, key.constant_args, /*returnTape*/ false, key.returnUsed,
4215+
key.additionalType, omp);
42504216

42514217
gutils->AtomicAdd = key.AtomicAdd;
42524218
gutils->FreeMemory = key.freeMemory;
@@ -4904,17 +4870,13 @@ Function *EnzymeLogic::CreateForwardDiff(
49044870

49054871
bool retActive = retType != DIFFE_TYPE::CONSTANT;
49064872

4907-
ReturnType retVal =
4908-
returnUsed ? (retActive ? ReturnType::TwoReturns : ReturnType::Return)
4909-
: (retActive ? ReturnType::Return : ReturnType::Void);
4910-
49114873
bool diffeReturnArg = false;
49124874

49134875
DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(
49144876
*this, mode, runtimeActivity, strongZero, width, todiff, TLI, TA,
49154877
oldTypeInfo, retType,
4916-
/*shadowReturn*/ retActive, diffeReturnArg, constant_args, retVal,
4917-
additionalArg, omp);
4878+
/*shadowReturn*/ retActive, diffeReturnArg, constant_args,
4879+
/*returnTape*/ false, returnUsed, additionalArg, omp);
49184880

49194881
insert_or_assign2<ForwardCacheKey, Function *>(ForwardCachedFunctions, tup,
49204882
gutils->newFunc);
@@ -5092,7 +5054,7 @@ Function *EnzymeLogic::CreateForwardDiff(
50925054
maker->visit(&*it);
50935055
}
50945056

5095-
createTerminator(gutils, &oBB, retType, retVal);
5057+
createTerminator(gutils, &oBB, retType, returnUsed, retActive);
50965058
}
50975059

50985060
if (mode == DerivativeMode::ForwardModeSplit && augmenteddata)

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 25 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -2423,31 +2423,18 @@ Function *PreProcessCache::preprocessForClone(Function *F,
24232423
return NewF;
24242424
}
24252425

2426-
FunctionType *getFunctionTypeForClone(
2427-
llvm::FunctionType *FTy, DerivativeMode mode, unsigned width,
2428-
llvm::Type *additionalArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
2429-
bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE returnType) {
2426+
FunctionType *getFunctionTypeForClone(llvm::FunctionType *FTy,
2427+
DerivativeMode mode, unsigned width,
2428+
llvm::Type *additionalArg,
2429+
llvm::ArrayRef<DIFFE_TYPE> constant_args,
2430+
bool diffeReturnArg, bool returnTape,
2431+
bool returnPrimal, bool returnShadow) {
24302432
SmallVector<Type *, 4> RetTypes;
2431-
if (returnValue == ReturnType::ArgsWithReturn ||
2432-
returnValue == ReturnType::Return) {
2433-
if (returnType != DIFFE_TYPE::CONSTANT &&
2434-
returnType != DIFFE_TYPE::OUT_DIFF) {
2435-
RetTypes.push_back(
2436-
GradientUtils::getShadowType(FTy->getReturnType(), width));
2437-
} else {
2438-
RetTypes.push_back(FTy->getReturnType());
2439-
}
2440-
} else if (returnValue == ReturnType::ArgsWithTwoReturns ||
2441-
returnValue == ReturnType::TwoReturns) {
2433+
if (returnPrimal)
24422434
RetTypes.push_back(FTy->getReturnType());
2443-
if (returnType != DIFFE_TYPE::CONSTANT &&
2444-
returnType != DIFFE_TYPE::OUT_DIFF) {
2445-
RetTypes.push_back(
2446-
GradientUtils::getShadowType(FTy->getReturnType(), width));
2447-
} else {
2448-
RetTypes.push_back(FTy->getReturnType());
2449-
}
2450-
}
2435+
if (returnShadow)
2436+
RetTypes.push_back(
2437+
GradientUtils::getShadowType(FTy->getReturnType(), width));
24512438
SmallVector<Type *, 4> ArgTypes;
24522439

24532440
// The user might be deleting arguments to the function by specifying them in
@@ -2459,7 +2446,7 @@ FunctionType *getFunctionTypeForClone(
24592446
if (constant_args[argno] == DIFFE_TYPE::DUP_ARG ||
24602447
constant_args[argno] == DIFFE_TYPE::DUP_NONEED) {
24612448
ArgTypes.push_back(GradientUtils::getShadowType(I, width));
2462-
} else if (constant_args[argno] == DIFFE_TYPE::OUT_DIFF) {
2449+
} else if (constant_args[argno] == DIFFE_TYPE::OUT_DIFF && !returnTape) {
24632450
RetTypes.push_back(GradientUtils::getShadowType(I, width));
24642451
}
24652452
++argno;
@@ -2474,31 +2461,19 @@ FunctionType *getFunctionTypeForClone(
24742461
ArgTypes.push_back(additionalArg);
24752462
}
24762463
Type *RetType = StructType::get(FTy->getContext(), RetTypes);
2477-
if (returnValue == ReturnType::TapeAndTwoReturns ||
2478-
returnValue == ReturnType::TapeAndReturn ||
2479-
returnValue == ReturnType::Tape) {
2480-
RetTypes.clear();
2481-
RetTypes.push_back(getDefaultAnonymousTapeType(FTy->getContext()));
2482-
if (returnValue == ReturnType::TapeAndTwoReturns) {
2483-
RetTypes.push_back(FTy->getReturnType());
2484-
RetTypes.push_back(
2485-
GradientUtils::getShadowType(FTy->getReturnType(), width));
2486-
} else if (returnValue == ReturnType::TapeAndReturn) {
2487-
if (returnType != DIFFE_TYPE::CONSTANT &&
2488-
returnType != DIFFE_TYPE::OUT_DIFF)
2489-
RetTypes.push_back(
2490-
GradientUtils::getShadowType(FTy->getReturnType(), width));
2491-
else
2492-
RetTypes.push_back(FTy->getReturnType());
2493-
}
2494-
RetType = StructType::get(FTy->getContext(), RetTypes);
2495-
} else if (returnValue == ReturnType::Return) {
2496-
assert(RetTypes.size() == 1);
2497-
RetType = RetTypes[0];
2498-
} else if (returnValue == ReturnType::TwoReturns) {
2499-
assert(RetTypes.size() == 2);
2464+
if (returnTape) {
2465+
RetTypes.insert(RetTypes.begin(),
2466+
getDefaultAnonymousTapeType(FTy->getContext()));
25002467
}
25012468

2469+
if (RetTypes.size() == 0)
2470+
RetType = Type::getVoidTy(RetType->getContext());
2471+
else if (RetTypes.size() == 1 && (returnPrimal || returnShadow) &&
2472+
mode != DerivativeMode::ReverseModeCombined)
2473+
RetType = RetTypes[0];
2474+
else
2475+
RetType = StructType::get(FTy->getContext(), RetTypes);
2476+
25022477
bool noReturn = RetTypes.size() == 0;
25032478
if (noReturn)
25042479
RetType = Type::getVoidTy(RetType->getContext());
@@ -2511,16 +2486,16 @@ Function *PreProcessCache::CloneFunctionWithReturns(
25112486
DerivativeMode mode, unsigned width, Function *&F,
25122487
ValueToValueMapTy &ptrInputs, ArrayRef<DIFFE_TYPE> constant_args,
25132488
SmallPtrSetImpl<Value *> &constants, SmallPtrSetImpl<Value *> &nonconstant,
2514-
SmallPtrSetImpl<Value *> &returnvals, ReturnType returnValue,
2515-
DIFFE_TYPE returnType, const Twine &name,
2489+
SmallPtrSetImpl<Value *> &returnvals, bool returnTape, bool returnPrimal,
2490+
bool returnShadow, const Twine &name,
25162491
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> *VMapO,
25172492
bool diffeReturnArg, llvm::Type *additionalArg) {
25182493
if (!F->empty())
25192494
F = preprocessForClone(F, mode);
25202495
llvm::ValueToValueMapTy VMap;
25212496
llvm::FunctionType *FTy = getFunctionTypeForClone(
25222497
F->getFunctionType(), mode, width, additionalArg, constant_args,
2523-
diffeReturnArg, returnValue, returnType);
2498+
diffeReturnArg, returnTape, returnPrimal, returnShadow);
25242499

25252500
for (BasicBlock &BB : *F) {
25262501
if (auto ri = dyn_cast<ReturnInst>(BB.getTerminator())) {

enzyme/Enzyme/FunctionUtils.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -105,8 +105,8 @@ class PreProcessCache {
105105
llvm::ArrayRef<DIFFE_TYPE> constant_args,
106106
llvm::SmallPtrSetImpl<llvm::Value *> &constants,
107107
llvm::SmallPtrSetImpl<llvm::Value *> &nonconstant,
108-
llvm::SmallPtrSetImpl<llvm::Value *> &returnvals, ReturnType returnValue,
109-
DIFFE_TYPE returnType, const llvm::Twine &name,
108+
llvm::SmallPtrSetImpl<llvm::Value *> &returnvals, bool returnTape,
109+
bool returnPrimal, bool returnShadow, const llvm::Twine &name,
110110
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> *VMapO,
111111
bool diffeReturnArg, llvm::Type *additionalArg = nullptr);
112112

@@ -412,7 +412,7 @@ bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val);
412412
llvm::FunctionType *getFunctionTypeForClone(
413413
llvm::FunctionType *FTy, DerivativeMode mode, unsigned width,
414414
llvm::Type *additionalArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
415-
bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE returnType);
415+
bool diffeReturnArg, bool returnTape, bool returnPrimal, bool returnShadow);
416416

417417
/// Lower __enzyme_todense, returning if changed.
418418
bool LowerSparsification(llvm::Function *F, bool replaceAll = true);

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -4352,16 +4352,6 @@ GradientUtils *GradientUtils::CreateFromClone(
43524352
++returnCount;
43534353
}
43544354

4355-
ReturnType returnValue;
4356-
if (returnCount == 0)
4357-
returnValue = ReturnType::Tape;
4358-
else if (returnCount == 1)
4359-
returnValue = ReturnType::TapeAndReturn;
4360-
else if (returnCount == 2)
4361-
returnValue = ReturnType::TapeAndTwoReturns;
4362-
else
4363-
llvm_unreachable("illegal number of elements in augmented return struct");
4364-
43654355
ValueToValueMapTy invertedPointers;
43664356
SmallPtrSet<Instruction *, 4> constants;
43674357
SmallPtrSet<Instruction *, 20> nonconstant;
@@ -4380,7 +4370,8 @@ GradientUtils *GradientUtils::CreateFromClone(
43804370
auto newFunc = Logic.PPC.CloneFunctionWithReturns(
43814371
DerivativeMode::ReverseModePrimal, width, oldFunc, invertedPointers,
43824372
constant_args, constant_values, nonconstant_values, returnvals,
4383-
/*returnValue*/ returnValue, retType, prefix, &originalToNew,
4373+
/*returnTape*/ true, /*returnPrimal*/ returnUsed,
4374+
/*returnShadow*/ shadowReturnUsed, prefix, &originalToNew,
43844375
/*diffeReturnArg*/ false, /*additionalArg*/ nullptr);
43854376

43864377
// Convert overwritten args from the input function to the preprocessed

0 commit comments

Comments
 (0)