@@ -1331,6 +1331,7 @@ void PullbackEmitter::visitSILInstruction(SILInstruction *inst) {
1331
1331
AllocStackInst *
1332
1332
PullbackEmitter::getArrayAdjointElementBuffer (SILValue arrayAdjoint,
1333
1333
int eltIndex, SILLocation loc) {
1334
+ auto &ctx = builder.getASTContext ();
1334
1335
auto arrayTanType = cast<StructType>(arrayAdjoint->getType ().getASTType ());
1335
1336
auto arrayType = arrayTanType->getParent ()->castTo <BoundGenericStructType>();
1336
1337
auto eltTanType = arrayType->getGenericArgs ().front ()->getCanonicalType ();
@@ -1340,7 +1341,19 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
1340
1341
auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct ();
1341
1342
auto subscriptLookup =
1342
1343
arrayTanStructDecl->lookupDirect (DeclBaseName::createSubscript ());
1343
- auto *subscriptDecl = cast<SubscriptDecl>(subscriptLookup.front ());
1344
+ SubscriptDecl *subscriptDecl = nullptr ;
1345
+ for (auto *candidate : subscriptLookup) {
1346
+ auto candidateModule = candidate->getModuleContext ();
1347
+ if (candidateModule->getName () == ctx.Id_Differentiation ||
1348
+ candidateModule->isStdlibModule ()) {
1349
+ assert (!subscriptDecl && " Multiple `Array.TangentVector.subscript`s" );
1350
+ subscriptDecl = cast<SubscriptDecl>(candidate);
1351
+ #ifdef NDEBUG
1352
+ break ;
1353
+ #endif
1354
+ }
1355
+ }
1356
+ assert (subscriptDecl && " No `Array.TangentVector.subscript`" );
1344
1357
auto *subscriptGetterDecl = subscriptDecl->getAccessor (AccessorKind::Get);
1345
1358
assert (subscriptGetterDecl && " No `Array.TangentVector.subscript` getter" );
1346
1359
SILOptFunctionBuilder fb (getContext ().getTransform ());
@@ -1352,7 +1365,6 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
1352
1365
subscriptGetterFn->getLoweredFunctionType ()->getSubstGenericSignature ();
1353
1366
// Apply `Array.TangentVector.subscript.getter` to get array element adjoint
1354
1367
// buffer.
1355
- auto &ctx = builder.getASTContext ();
1356
1368
// %index_literal = integer_literal $Builtin.IntXX, <index>
1357
1369
auto builtinIntType =
1358
1370
SILType::getPrimitiveObjectType (ctx.getIntDecl ()
0 commit comments