Skip to content

Commit c0bade1

Browse files
committed
Add Enzyme sum derivatives
1 parent 76e2972 commit c0bade1

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed

ext/EnzymeCoreExt.jl

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ else
1212
using ..EnzymeCore
1313
using ..EnzymeCore.EnzymeRules
1414
end
15+
using GPUArrays
1516

1617
function EnzymeCore.EnzymeRules.inactive(::typeof(CUDA.CUBLAS.handle))
1718
return nothing
@@ -489,5 +490,163 @@ function EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...)
489490
return nothing
490491
end
491492

493+
# Forward-mode rule for `GPUArrays.mapreducedim!(identity, Base.add_sum, R, A; init)`.
#
# Arguments
#   ofn  - the constant `mapreducedim!` callable; it is re-invoked for both the
#          primal and the shadow reductions.
#   RT   - requested return annotation type; selects what is returned.
#   f/op - constant `identity` / `Base.add_sum`.
#   R    - annotated destination array (element type `T`).
#   A    - annotated source; `init` is the keyword supplied by the caller.
#
# The primal reduction runs only when `R` carries a primal that is needed
# (`Const`, `Duplicated`, `BatchDuplicated`). The shadow(s) of `R` are then
# either zero-filled (constant `A`) or filled by reducing the shadow(s) of `A`.
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays.mapreducedim!)},
                                        ::Type{RT},
                                        f::EnzymeCore.Const{typeof(Base.identity)},
                                        op::EnzymeCore.Const{typeof(Base.add_sum)},
                                        R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T}
    if R isa Const || R isa Duplicated || R isa BatchDuplicated
        ofn.val(f.val, op.val, R.val, A.val; init)
    end

    # The shadow that gets written belongs to `R`, so the branch is selected by
    # the annotation of `R` (mirroring the batched branch below). Selecting on
    # `A` made the `A isa Const` zero-fill unreachable and accessed `R.dval`
    # even when `R` had no shadow.
    # NOTE(review): the shadow reductions do not forward `init`; confirm that
    # this is intended.
    if R isa Duplicated || R isa DuplicatedNoNeed
        if A isa Const
            Base.fill!(R.dval, zero(T))
        else
            ofn.val(f.val, op.val, R.dval, A.dval)
        end
    elseif R isa BatchDuplicated || R isa BatchDuplicatedNoNeed
        ntuple(Val(EnzymeRules.batch_width(R))) do i
            Base.@_inline_meta
            if A isa Const
                Base.fill!(R.dval[i], zero(T))
            else
                ofn.val(f.val, op.val, R.dval[i], A.dval[i])
            end
            nothing
        end
    end

    # Hand back whatever the requested return annotation asks for.
    if RT <: Duplicated
        return R
    elseif RT <: Const
        return R.val
    elseif RT <: DuplicatedNoNeed
        return R.dval
    elseif RT <: BatchDuplicated
        return R
    else
        return R.dval
    end
end
532+
533+
534+
# Reverse-mode forward pass (augmented primal) for
# `GPUArrays.mapreducedim!(identity, Base.add_sum, R, A; init)`.
#
# Runs the primal reduction into `R.val` and returns an `AugmentedReturn`
# holding the primal (`R.val`) and shadow (`R.dval`) only when `config` says
# they are needed. No tape is recorded (`nothing`).
function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(GPUArrays.mapreducedim!)},
                                                 ::Type{RT},
                                                 f::EnzymeCore.Const{typeof(Base.identity)},
                                                 op::EnzymeCore.Const{typeof(Base.add_sum)},
                                                 R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T<:AbstractFloat}
    # NOTE(review): this guard inspects `A`, whereas the forward-mode rule for
    # the same function inspects `R`; confirm which annotation should gate the
    # primal call.
    if A isa Const || A isa Duplicated || A isa BatchDuplicated
        # Forward the caller's `init` so the primal matches the original call;
        # it was previously dropped here.
        ofn.val(f.val, op.val, R.val, A.val; init)
    end

    primal = if EnzymeRules.needs_primal(config)
        R.val
    else
        nothing
    end

    shadow = if EnzymeRules.needs_shadow(config)
        R.dval
    else
        nothing
    end
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end
556+
557+
# Reverse pass for `GPUArrays.mapreducedim!(identity, Base.add_sum, R, A; init)`.
#
# When neither `A` nor `R` is constant, the shadow of `R` is broadcast-added
# into the shadow of `A` and then reset to zero (per batch lane for batched
# annotations). No argument receives an active gradient, hence one `nothing`
# per argument (`f`, `op`, `R`, `A`).
function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays.mapreducedim!)},
                                        ::Type{RT},
                                        tape,
                                        f::EnzymeCore.Const{typeof(Base.identity)},
                                        op::EnzymeCore.Const{typeof(Base.add_sum)},
                                        R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T<:AbstractFloat}
    no_grads = (nothing, nothing, nothing, nothing)

    # Nothing to propagate if either side is constant.
    (A isa Const || R isa Const) && return no_grads

    if A isa Union{Duplicated, DuplicatedNoNeed}
        A.dval .+= R.dval
        Base.fill!(R.dval, zero(T))
    elseif A isa Union{BatchDuplicated, BatchDuplicatedNoNeed}
        width = Val(EnzymeRules.batch_width(A))
        ntuple(width) do lane
            Base.@_inline_meta
            dA = A.dval[lane]
            dR = R.dval[lane]
            dA .+= dR
            Base.fill!(dR, zero(T))
            nothing
        end
    end

    return no_grads
end
580+
581+
# Forward-mode rule for `GPUArrays._mapreduce(identity, Base.add_sum, A; dims, init)`.
#
# Because the reduction is a plain sum, the tangent of the result is the same
# reduction applied to the shadow(s) of `A`. What is returned depends on the
# requested return annotation `RT`:
#   Const                 -> primal result only
#   Duplicated            -> `Duplicated(primal, shadow)`
#   DuplicatedNoNeed      -> shadow only
#   BatchDuplicated       -> `BatchDuplicated(primal, (shadow_1, ..., shadow_N))`
#   BatchDuplicatedNoNeed -> tuple of shadows
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays._mapreduce)},
                                        ::Type{RT},
                                        f::EnzymeCore.Const{typeof(Base.identity)},
                                        op::EnzymeCore.Const{typeof(Base.add_sum)},
                                        A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T, D}
    if RT <: Const
        return ofn.val(f.val, op.val, A.val; dims, init)
    elseif RT <: Duplicated
        # Wrap in the annotation type rather than returning a bare tuple.
        return Duplicated(
            ofn.val(f.val, op.val, A.val; dims, init),
            ofn.val(f.val, op.val, A.dval; dims, init),
        )
    elseif RT <: DuplicatedNoNeed
        return ofn.val(f.val, op.val, A.dval; dims, init)
    elseif RT <: BatchDuplicated
        # Wrap in the annotation type rather than returning a bare tuple.
        return BatchDuplicated(
            ofn.val(f.val, op.val, A.val; dims, init),
            ntuple(Val(EnzymeRules.batch_width(RT))) do i
                Base.@_inline_meta
                ofn.val(f.val, op.val, A.dval[i]; dims, init)
            end,
        )
    else
        @assert RT <: BatchDuplicatedNoNeed
        return ntuple(Val(EnzymeRules.batch_width(RT))) do i
            Base.@_inline_meta
            ofn.val(f.val, op.val, A.dval[i]; dims, init)
        end
    end
end
611+
612+
# Reverse-mode forward pass (augmented primal) for
# `GPUArrays._mapreduce(identity, Base.add_sum, A; dims, init)` with an
# `Active` return.
#
# The reduction is evaluated only if `config` requests the primal; `A.dval` is
# handed back as shadow only if `config` requests one. No tape is stored.
function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(GPUArrays._mapreduce)},
                                                 ::Type{Active{RT}},
                                                 f::EnzymeCore.Const{typeof(Base.identity)},
                                                 op::EnzymeCore.Const{typeof(Base.add_sum)},
                                                 A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T<:AbstractFloat, D}
    result = EnzymeRules.needs_primal(config) ?
        ofn.val(f.val, op.val, A.val; dims, init) : nothing
    tangent = EnzymeRules.needs_shadow(config) ? A.dval : nothing
    return EnzymeRules.AugmentedReturn(result, tangent, nothing)
end
630+
631+
# Reverse pass for `GPUArrays._mapreduce(identity, Base.add_sum, A; dims, init)`
# with an `Active` return.
#
# The incoming adjoint `dres.val` is broadcast-added into the shadow of `A`
# (into every batch lane for batched annotations). Nothing is done for other
# annotations of `A`.
#
# Returns one entry per differentiable argument (`f`, `op`, `A`); none of them
# is `Active`, so each entry is `nothing`.
function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays._mapreduce)},
                                        dres::Active{RT},
                                        tape,
                                        f::EnzymeCore.Const{typeof(Base.identity)},
                                        op::EnzymeCore.Const{typeof(Base.add_sum)},
                                        A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T<:AbstractFloat, D}

    if A isa Duplicated || A isa DuplicatedNoNeed
        A.dval .+= dres.val
    elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed
        ntuple(Val(EnzymeRules.batch_width(A))) do i
            Base.@_inline_meta
            A.dval[i] .+= dres.val
            nothing
        end
    end

    # Three arguments (f, op, A) -> three results. The previous four-element
    # tuple matched the `mapreducedim!` rule, which has an extra `R` argument.
    return (nothing, nothing, nothing)
end
650+
492651
end # module
493652

test/extensions/enzyme.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,20 @@ firstsum(x, y) = first(x .+ y)
103103
#@test res[2] ≈ 1.2
104104
end
105105

106+
@testset "Forward sum" begin
    x = CuArray([1.0, 2.0, 3.0, 4.0])
    dx = CuArray([100., 300.0, 500.0, 700.0])
    res = Enzyme.autodiff(Forward, sum, Duplicated(x, dx))
    # The tangent of `sum` is the sum of the seed `dx`. The original line had
    # no comparison operator, which is not a valid `@test` call.
    # `res[1]` selects the first element of the returned result.
    @test res[1] ≈ 100 + 300 + 500 + 700
end
112+
113+
@testset "Reverse sum" begin
    # d(sum)/dx is 1 for every element, accumulated into a zeroed shadow.
    primal = CuArray([1.0, 2.0, 3.0, 4.0])
    shadow = CuArray([0., 0.0, 0.0, 0.0])
    Enzyme.autodiff(Reverse, sum, Duplicated(primal, shadow))
    @test all(isapprox.(shadow, 1.0))
end
119+
106120
# TODO once reverse kernels are in
107121
# function togpu(x)
108122
# x = CuArray(x)

0 commit comments

Comments
 (0)