Skip to content

Commit cfb4251

Browse files
authored
Fix insertvalue issue (#2448)
* Fix insertvalue issue * fix * fix fwd ret
1 parent 76026d5 commit cfb4251

File tree

4 files changed

+311
-4
lines changed

4 files changed

+311
-4
lines changed

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ inline bool is_value_needed_in_reverse(
179179
}
180180
}
181181

182-
if (!TR.anyFloat(const_cast<Value *>(inst)))
182+
if (!TR.allFloat(const_cast<Value *>(inst)))
183183
if (auto IVI = dyn_cast<Instruction>(user)) {
184184
bool inserted = false;
185185
if (auto II = dyn_cast<InsertValueInst>(IVI))
@@ -215,10 +215,32 @@ inline bool is_value_needed_in_reverse(
215215
}
216216

217217
bool partial = false;
218-
if (!gutils->isConstantValue(const_cast<Instruction *>(cur))) {
219-
partial = is_value_needed_in_reverse<QueryType::Shadow>(
220-
gutils, user, mode, seen, oldUnreachable);
218+
if (auto UI = dyn_cast<Instruction>(u)) {
219+
if (!gutils->isConstantValue(
220+
const_cast<Instruction *>(cur))) {
221+
bool recursiveUse = false;
222+
if (is_use_directly_needed_in_reverse(
223+
gutils, cur, mode, UI, oldUnreachable,
224+
QueryType::Shadow, &recursiveUse)) {
225+
partial = true;
226+
} else if (recursiveUse && !OneLevel) {
227+
partial = is_value_needed_in_reverse<QueryType::Shadow>(
228+
gutils, UI, mode, seen, oldUnreachable);
229+
}
230+
} else if (VT == QueryType::Shadow) {
231+
bool recursiveUse = false;
232+
if (is_use_directly_needed_in_reverse(
233+
gutils, cur, mode, UI, oldUnreachable,
234+
QueryType::ShadowByConstPrimal, &recursiveUse)) {
235+
partial = true;
236+
} else if (recursiveUse && !OneLevel) {
237+
partial = is_value_needed_in_reverse<
238+
QueryType::ShadowByConstPrimal>(gutils, UI, mode,
239+
seen, oldUnreachable);
240+
}
241+
}
221242
}
243+
222244
if (partial) {
223245

224246
if (EnzymePrintDiffUse)

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6099,6 +6099,35 @@ size_t skippedBytes(SmallSet<size_t, 8> &offs, Type *T, const DataLayout &DL,
60996099
return prevOff;
61006100
}
61016101

6102+
bool TypeResults::allFloat(Value *val) const {
6103+
assert(val);
6104+
assert(val->getType());
6105+
auto q = query(val);
6106+
auto dt = q[{-1}];
6107+
if (dt != BaseType::Anything && dt != BaseType::Unknown)
6108+
return dt.isFloat();
6109+
6110+
if (val->getType()->isTokenTy() || val->getType()->isVoidTy())
6111+
return false;
6112+
auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
6113+
SmallSet<size_t, 8> offs;
6114+
size_t ObjSize = skippedBytes(offs, val->getType(), dl);
6115+
6116+
for (size_t i = 0; i < ObjSize;) {
6117+
dt = q[{(int)i}];
6118+
if (auto FT = dt.isFloat()) {
6119+
i += (dl.getTypeSizeInBits(FT) + 7) / 8;
6120+
continue;
6121+
}
6122+
if (offs.count(i)) {
6123+
i++;
6124+
continue;
6125+
}
6126+
return false;
6127+
}
6128+
return true;
6129+
}
6130+
61026131
bool TypeResults::anyFloat(Value *val, bool anythingIsFloat) const {
61036132
assert(val);
61046133
assert(val->getType());

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ class TypeResults {
205205
// be considered a float.
206206
bool anyFloat(llvm::Value *val, bool anythingIsFloat = true) const;
207207

208+
/// Whether all of the top level register is known to contain float data
209+
bool allFloat(llvm::Value *val) const;
210+
208211
/// Whether any part of the top level register can contain a pointer
209212
/// e.g. { i64, i8* } can contain a pointer, but { i64, float } would not.
210213
// Of course, here we compute with type analysis rather than llvm type

0 commit comments

Comments
 (0)