Add Enzyme sum derivatives #2471
098f13d to c0bade1
@maleadt, would you mind giving this a review? All tests pass.
I can't review this. I don't have any Enzyme.jl experience, so the code looks foreign to me.
With this PR, Enzyme main (44febc), GPUCompiler 0.27.1 and Julia 1.10.3, the following fails for me:

```julia
using Enzyme, CUDA

Enzyme.API.runtimeActivity!(true)

# Sum over a broadcasted expression
f(x, y) = sum(x .+ y)

x = CuArray(rand(5))
y = CuArray(rand(5))
dx = CuArray([1.0, 0.0, 0.0, 0.0, 0.0])  # tangent seed for x

# Forward mode
autodiff(Forward, f, Duplicated, Duplicated(x, dx), Const(y))
```

And for reverse mode:

```julia
# Reverse mode
autodiff(Reverse, f, Active, Duplicated(x, dx), Const(y))
```
Yeah, @jgreener64, this adds the `sum` derivatives for arrays, but not for broadcasted expressions yet (so you'll need to actually materialize the `.+` result first).
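To make the distinction concrete, here is a minimal CPU-only sketch of what the two `sum` derivatives compute once the `.+` result has been materialized. It assumes plain Julia `Array`s instead of `CuArray`s and does not use Enzyme at all. The helpers `sum_forward` and `sum_reverse!` are hypothetical names for illustration; the rules added in this PR may be structured differently.

```julia
# Conceptual sketch of the derivative rules for `sum`, on plain CPU arrays.
# `sum_forward` and `sum_reverse!` are hypothetical helpers for illustration;
# they are not part of Enzyme.jl or CUDA.jl.

# Forward mode: the tangent of s = sum(z) is the sum of the input tangents.
sum_forward(z::AbstractArray, dz::AbstractArray) = (sum(z), sum(dz))

# Reverse mode: every element enters the sum with weight 1, so the adjoint
# `ds` of the scalar output is added to each entry of the shadow `dz`.
function sum_reverse!(dz::AbstractArray, ds::Real)
    dz .+= ds
    return dz
end

# Fixed inputs instead of rand(5) so the results are reproducible.
x = [0.1, 0.2, 0.3, 0.4, 0.5]
y = [1.0, 2.0, 3.0, 4.0, 5.0]
dx = [1.0, 0.0, 0.0, 0.0, 0.0]  # tangent seed for x

# Materialize the broadcast first, then differentiate the plain array sum.
z = x .+ y             # materialized `.+` result
dz = dx .+ zero.(y)    # y is constant, so it contributes a zero tangent

s, ds = sum_forward(z, dz)
println("sum = ", s, ", tangent = ", ds)

# Reverse mode: seed the scalar output with 1.0. Since z = x .+ y, the
# gradient with respect to x equals the gradient with respect to z.
gx = sum_reverse!(zeros(length(z)), 1.0)
println("gradient wrt x = ", gx)
```

With `dx` seeding only the first element, the forward tangent is `1.0`, and the gradient of the sum with respect to `x` is a vector of ones.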
@maleadt, now that CI is happy again, can this be merged?
No description provided.