@@ -3166,7 +3166,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
3166
3166
}
3167
3167
3168
3168
void createTerminator (DiffeGradientUtils *gutils, BasicBlock *oBB,
3169
- DIFFE_TYPE retType, ReturnType retVal) {
3169
+ DIFFE_TYPE retType, bool returnPrimal,
3170
+ bool returnShadow) {
3170
3171
TypeResults &TR = gutils->TR ;
3171
3172
ReturnInst *inst = dyn_cast<ReturnInst>(oBB->getTerminator ());
3172
3173
// In forward mode we only need to update the return value
@@ -3212,75 +3213,45 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
3212
3213
}
3213
3214
}
3214
3215
3215
- switch (retVal) {
3216
- case ReturnType::Return: {
3217
- auto ret = inst->getOperand (0 );
3218
-
3219
- Type *rt = ret->getType ();
3220
- while (auto AT = dyn_cast<ArrayType>(rt))
3221
- rt = AT->getElementType ();
3222
- bool floatLike = rt->isFPOrFPVectorTy ();
3223
-
3224
- if (retType == DIFFE_TYPE::CONSTANT) {
3225
- toret = gutils->getNewFromOriginal (ret);
3226
- } else if (!floatLike &&
3227
- TR.getReturnAnalysis ().Inner0 ().isPossiblePointer ()) {
3228
- toret = invertedPtr ? invertedPtr : gutils->invertPointerM (ret, nBuilder);
3229
- } else if (!gutils->isConstantValue (ret)) {
3230
- assert (!invertedPtr);
3231
- toret = gutils->diffe (ret, nBuilder);
3232
- } else {
3233
- toret = invertedPtr
3234
- ? invertedPtr
3235
- : gutils->invertPointerM (ret, nBuilder, /* nullInit*/ true );
3236
- }
3216
+ Value *primal = nullptr ;
3217
+ Value *shadow = nullptr ;
3237
3218
3238
- break ;
3219
+ if (returnPrimal) {
3220
+ auto ret = inst->getOperand (0 );
3221
+ primal = gutils->getNewFromOriginal (ret);
3239
3222
}
3240
- case ReturnType::TwoReturns: {
3241
- if (retType == DIFFE_TYPE::CONSTANT)
3242
- assert (false && " Invalid return type" );
3223
+ if (returnShadow) {
3243
3224
auto ret = inst->getOperand (0 );
3244
-
3245
3225
Type *rt = ret->getType ();
3246
3226
while (auto AT = dyn_cast<ArrayType>(rt))
3247
3227
rt = AT->getElementType ();
3248
3228
bool floatLike = rt->isFPOrFPVectorTy ();
3249
3229
3250
- toret =
3251
- nBuilder.CreateInsertValue (toret, gutils->getNewFromOriginal (ret), 0 );
3252
-
3253
3230
if (!floatLike && TR.getReturnAnalysis ().Inner0 ().isPossiblePointer ()) {
3254
- toret = nBuilder.CreateInsertValue (
3255
- toret,
3256
- invertedPtr ? invertedPtr : gutils->invertPointerM (ret, nBuilder), 1 );
3231
+ shadow =
3232
+ invertedPtr ? invertedPtr : gutils->invertPointerM (ret, nBuilder);
3257
3233
} else if (!gutils->isConstantValue (ret)) {
3258
3234
assert (!invertedPtr);
3259
- toret =
3260
- nBuilder.CreateInsertValue (toret, gutils->diffe (ret, nBuilder), 1 );
3235
+ shadow = gutils->diffe (ret, nBuilder);
3261
3236
} else {
3262
- toret = nBuilder.CreateInsertValue (
3263
- toret,
3264
- invertedPtr
3265
- ? invertedPtr
3266
- : gutils->invertPointerM (ret, nBuilder, /* nullInit*/ true ),
3267
- 1 );
3237
+ shadow = invertedPtr
3238
+ ? invertedPtr
3239
+ : gutils->invertPointerM (ret, nBuilder, /* nullInit*/ true );
3268
3240
}
3269
- break ;
3270
3241
}
3271
- case ReturnType::Void: {
3242
+
3243
+ if (primal && shadow) {
3244
+ toret = nBuilder.CreateInsertValue (toret, primal, 0 );
3245
+ toret = nBuilder.CreateInsertValue (toret, shadow, 1 );
3246
+ } else if (primal) {
3247
+ toret = primal;
3248
+ } else if (shadow) {
3249
+ toret = shadow;
3250
+ } else {
3272
3251
gutils->erase (gutils->getNewFromOriginal (inst));
3273
3252
nBuilder.CreateRetVoid ();
3274
3253
return ;
3275
3254
}
3276
- default : {
3277
- llvm::errs () << " Invalid return type: " << to_string (retVal)
3278
- << " for function: \n "
3279
- << gutils->newFunc << " \n " ;
3280
- assert (false && " Invalid return type for function" );
3281
- return ;
3282
- }
3283
- }
3284
3255
3285
3256
gutils->erase (newInst);
3286
3257
nBuilder.CreateRet (toret);
@@ -4234,19 +4205,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
4234
4205
assert (augmenteddata->constant_args == key.constant_args );
4235
4206
}
4236
4207
4237
- ReturnType retVal =
4238
- key.returnUsed ? (key.shadowReturnUsed ? ReturnType::ArgsWithTwoReturns
4239
- : ReturnType::ArgsWithReturn)
4240
- : (key.shadowReturnUsed ? ReturnType::ArgsWithReturn
4241
- : ReturnType::Args);
4242
-
4243
4208
bool diffeReturnArg = key.retType == DIFFE_TYPE::OUT_DIFF;
4244
4209
4245
4210
DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone (
4246
4211
*this , key.mode , key.runtimeActivity , key.strongZero , key.width ,
4247
4212
key.todiff , TLI, TA, oldTypeInfo, key.retType ,
4248
4213
augmenteddata ? augmenteddata->shadowReturnUsed : key.shadowReturnUsed ,
4249
- diffeReturnArg, key.constant_args , retVal, key.additionalType , omp);
4214
+ diffeReturnArg, key.constant_args , /* returnTape*/ false , key.returnUsed ,
4215
+ key.additionalType , omp);
4250
4216
4251
4217
gutils->AtomicAdd = key.AtomicAdd ;
4252
4218
gutils->FreeMemory = key.freeMemory ;
@@ -4904,17 +4870,13 @@ Function *EnzymeLogic::CreateForwardDiff(
4904
4870
4905
4871
bool retActive = retType != DIFFE_TYPE::CONSTANT;
4906
4872
4907
- ReturnType retVal =
4908
- returnUsed ? (retActive ? ReturnType::TwoReturns : ReturnType::Return)
4909
- : (retActive ? ReturnType::Return : ReturnType::Void);
4910
-
4911
4873
bool diffeReturnArg = false ;
4912
4874
4913
4875
DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone (
4914
4876
*this , mode, runtimeActivity, strongZero, width, todiff, TLI, TA,
4915
4877
oldTypeInfo, retType,
4916
- /* shadowReturn*/ retActive, diffeReturnArg, constant_args, retVal,
4917
- additionalArg, omp);
4878
+ /* shadowReturn*/ retActive, diffeReturnArg, constant_args,
4879
+ /* returnTape */ false , returnUsed, additionalArg, omp);
4918
4880
4919
4881
insert_or_assign2<ForwardCacheKey, Function *>(ForwardCachedFunctions, tup,
4920
4882
gutils->newFunc );
@@ -5092,7 +5054,7 @@ Function *EnzymeLogic::CreateForwardDiff(
5092
5054
maker->visit (&*it);
5093
5055
}
5094
5056
5095
- createTerminator (gutils, &oBB, retType, retVal );
5057
+ createTerminator (gutils, &oBB, retType, returnUsed, retActive );
5096
5058
}
5097
5059
5098
5060
if (mode == DerivativeMode::ForwardModeSplit && augmenteddata)
0 commit comments