-
Notifications
You must be signed in to change notification settings - Fork 21
feat: propagate bounds in whileop #1453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
68b2367
to
37e44ff
Compare
auto step = info.getConstantStep().value(); | ||
|
||
// currently only removes remainder of induction variable | ||
whileBody.walk([&](stablehlo::RemOp remOp) -> WalkResult { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
instead could we just look at all users of the induction var (that likely will be faster than a walk of the body)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EnzymeJAX Benchmarks
Benchmark suite | Current: e358839 | Previous: 54edf38 | Ratio |
---|---|---|---|
scatter_sum / JaX / cpu / Primal |
0.000004446296999958577 s |
0.000004485886999464128 s |
0.99 |
scatter_sum / JaXPipe / cpu / Primal |
0.000004340780000347877 s |
0.000004439690998697188 s |
0.98 |
scatter_sum / JaX / tpu / Primal |
0.0001532494204999 s |
0.0001577955517001 s |
0.97 |
scatter_sum / JaXPipe / tpu / Primal |
0.0001514654375001 s |
0.0001410659988992 s |
1.07 |
This comment was automatically generated by workflow using github-action-benchmark.
efaec18
to
fc223de
Compare
} | ||
|
||
// Initialize bounds map with induction variable bounds | ||
DenseMap<Value, Bounds> boundsMap; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
min/max only applies if step >= 0, if negative this needs to be changed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the induction var is initialized as
APInt minBound(bitWidth, std::min(start, limit), true);
APInt maxBound(bitWidth, std::max(start, limit), true);
boundsMap[inductionVar] = Bounds(minBound, maxBound);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hm, is min/max inclusive or exclusive then. If step > 0, range is [start, limit). If negative it's (limit, start]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inclusive in the subsequent usage, needs to be fixed here
return false; | ||
}; | ||
|
||
bool rewriteCompareOp(PatternRewriter &rewriter, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems somewhat (but not wholly) similar to:
Enzyme-JAX/src/enzyme_ad/jax/Passes/AffineCFG.cpp
Line 5941 in 9861ed8
bool valueCmp(Cmp cmp, Value bval, ValueOrInt val) { |
DenseMap<Value, Bounds> boundsMap; | ||
if (step > 0) { | ||
APInt minBound(bitWidth, start, true); | ||
APInt maxBound(bitWidth, limit - 1, true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this isn't necessarily correct, since start == limit is a valid case [where the loop body isn't executed]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
though I suppose if the body is never executed, what is done in the after region doesn't matter anyways
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a check for loop count >= 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not seeing where since if step > limit - start, loop count is still zero [and step only checked for sign]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bit above
if (!info.isValid() || !info.isConstant() ||
info.getConstantNumIters() <= 0)
return failure();
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah ok
600c987
to
e116b87
Compare
f0beb10
to
e358839
Compare
No description provided.