Merge pull request #90 from DrChainsaw/funcmemo
Change examples and test to use explicit gradients
DrChainsaw authored Aug 7, 2023
2 parents 7d1150e + 36d24bb commit f502e5f
Showing 10 changed files with 133 additions and 96 deletions.
6 changes: 5 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "NaiveNASflux"
uuid = "85610aed-7d32-5e57-bb50-4c2e1c9e7997"
version = "2.0.9"
version = "2.0.10"

ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -9,6 +9,8 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NaiveNASlib = "bd45eb3e-47ce-54bd-9eaf-e86c5f900853"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand All @@ -20,6 +22,8 @@ Flux = "0.13, 0.14"
Functors = "0.2, 0.3, 0.4"
JuMP = "0.19, 0.20, 0.21, 0.22, 0.23, 1"
NaiveNASlib = "2.0.6"
Optimisers = "0.2"
PrecompileTools = "1"
Reexport = "0.2.0, 1.0"
Setfield = "0.3.4, 0.5, 0.6, 0.7, 0.8, 1.0"
julia = "1"
5 changes: 4 additions & 1 deletion src/NaiveNASflux.jl
Expand Up @@ -6,6 +6,7 @@ using NaiveNASlib.Extend, NaiveNASlib.Advanced
import Flux
using Flux: Dense, Conv, ConvTranspose, CrossCor, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
MaxPool, MeanPool, Dropout, AlphaDropout, GlobalMaxPool, GlobalMeanPool, cpu
import Optimisers
import Functors
using Functors: @functor
using Statistics
Expand Down Expand Up @@ -39,4 +40,6 @@ include("neuronutility.jl")
# Stuff to integrate with Zygote

end # module

end # module
8 changes: 4 additions & 4 deletions src/chainrules.jl
Expand Up @@ -11,8 +11,8 @@ ChainRulesCore.@non_differentiable mutate(args...)
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, m::MutableLayer, args...)
res, back = rrule_via_ad(config, m.layer, args...)
function MutableLayer_back(Δ)
δlayer, δargs = back(Δ)
Tangent{MutableLayer}(layer=δlayer), δargs
δs = back(Δ)
Tangent{MutableLayer}(layer=δs[1]), δs[2:end]...
return res, MutableLayer_back
Expand All @@ -23,8 +23,8 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, m::LazyMutab
res, back = rrule_via_ad(config, m.mutable, args...)
function LazyMutable_back(Δ)
δmutable, δargs = back(Δ)
Tangent{LazyMutable}(mutable=δmutable), δargs
δs = back(Δ)
Tangent{LazyMutable}(mutable=δs[1]), δs[2:end]...
return res, LazyMutable_back
1 change: 1 addition & 0 deletions src/mutable.jl
Expand Up @@ -11,6 +11,7 @@ Base.Broadcast.broadcastable(m::AbstractMutableComp) = Ref(m)

# Not possible in julia <= 1.1. See #14919
# (m::AbstractMutableComp)(x...) = layer(m)(x...)
# Generic fallbacks for mutable wrappers: unwrap one level with `wrapped` and ask again.
function layer(m::AbstractMutableComp)
    inner = wrapped(m)
    return layer(inner)
end

# The layer type of a wrapper is the layer type of the layer it (transitively) wraps.
function layertype(m::AbstractMutableComp)
    return layertype(layer(m))
end

# Input size of a wrapper is forwarded to whatever it wraps.
function NaiveNASlib.nin(m::AbstractMutableComp)
    return nin(wrapped(m))
end
80 changes: 44 additions & 36 deletions src/neuronutility.jl
ActivationContribution{L,M} <: AbstractMutableComp
ActivationContribution{L,C,M} <: AbstractMutableComp
ActivationContribution(l, method)
Calculate neuron utility based on activations and gradients using `method`.
Calculate neuron utility based on activations and gradients using `method`.
Designed to be used as `layerfun` argument to [`fluxvertex`](@ref).
Can be a performance bottleneck in cases with large activations. Use [`NeuronUtilityEvery`](@ref) to mitigate.
Expand All @@ -13,58 +15,63 @@ Short summary is that the first order taylor approximation of the optimization p
boils down to: "the ones which minimize `abs(gradient * activation)`" (assuming parameter independence).
struct ActivationContribution{L,M} <: AbstractMutableComp
struct ActivationContribution{L,C,M} <: AbstractMutableComp
contribution::Base.RefValue{Any} # Type of activation not known yet :( Also leave some room for experimenting with things like storing the metric on the GPU
# We use eps(Float32) here because we don't want parameters from new layers to have:
# 1) higher utility than parameters from existing layers
# 2) zero utility since that will often make the optimizer remove them completely
# eps(Float32) is typically smaller than the optimizer tolerance, but NaiveNASlib tries to rescale
ActivationContribution(l::AbstractMutableComp, method = Ewma(0.05f0)) = ActivationContribution(l, Ref{Any}(fill(eps(Float32), nout(l))), method)
ActivationContribution(l, method = Ewma(0.05f0)) = ActivationContribution(l, Ref{Any}(missing), method)

layer(m::ActivationContribution) = layer(m.layer)
layertype(m::ActivationContribution) = layertype(m.layer)
wrapped(m::ActivationContribution) = m.layer
ActivationContribution(l::AbstractMutableComp, method = Ewma(0.05f0)) = ActivationContribution(l, fill(eps(Float32), nout(l)), method)
ActivationContribution(l, method = Ewma(0.05f0)) = ActivationContribution(l, Float32[], method)

Flux.trainable(m::ActivationContribution) = (;layer = Flux.trainable(wrapped(m)))
@functor ActivationContribution

function Functors.functor(t::Type{<:ActivationContribution}, m::ActivationContribution)
_functor(t, m.layer, m.contribution[], m.method)

function Functors.functor(t::Type{<:ActivationContribution}, m)
_functor(t, m.layer, m.contribution, m.method)

function _functor(::Type{<:ActivationContribution}, layer, contribution, method)
return (;layer, contribution, method), function(y)
ActivationContribution(y.layer, Ref{Any}(y.contribution), y.method)

act = wrapped(m)(x...)
wrapped(m::ActivationContribution) = m.layer

return Flux.Zygote.hook(act) do grad
grad === nothing && return grad
m.contribution[] = m.method(m.contribution[], act, grad)
return grad
# The contribution is "trained" in some sense, but Flux must not be the one updating it.
# A "fake" gradient could be created in the rrule so that the optimizer rule performs the update
# for us (instead of our own Ewma), but mixing the model parameter update strategy with the
# activation contribution strategy is probably not desirable.
# Hence only the trainable parts of the wrapped layer are exposed.
function Flux.trainable(m::ActivationContribution)
    return (; layer = Flux.trainable(m.layer))
end

# Plain forward pass: when no gradients are taken the call is simply forwarded to the wrapped layer.
function (m::ActivationContribution)(x...)
    return wrapped(m)(x...)
end

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, m::T, x...) where T <:ActivationContribution
act, back = rrule_via_ad(config, wrapped(m), x...)

function ActivationContribution_back(Δ)
if length(m.contribution) === 0
newcontribution = m.method(missing, act, Δ)
resize!(m.contribution, length(newcontribution))
copyto!(m.contribution, newcontribution)
copyto!(m.contribution, m.method(m.contribution, act, Δ))

δs = back(Δ)
Tangent{T}(layer=δs[1]), δs[2:end]...
return act, ActivationContribution_back

# Return `nd - 1` for an integer `nd`.
# NOTE(review): presumably `nd` is the number of array dimensions and the result is the index
# of the activation (channel) dimension - confirm against callers.
function actdim(nd::Integer)
    return nd - 1
end

function NaiveNASlib.Δsize!(m::ActivationContribution, inputs::AbstractVector, outputs::AbstractVector; kwargs...)
if m.contribution[] !== missing
if length(m.contribution) !== 0
# This tends to happen when we are measuring contribution for a concatenation and we have added an extra input edge
# TODO: Try to find another fix; perhaps we need to ensure that nout(v) always returns
# the length of m.contribution when v wraps an ActivationContribution
outputs[outputs .> length(m.contribution[])] .= -1
outputs[outputs .> length(m.contribution)] .= -1
newcontribution = select(m.contribution, 1 => outputs; newfun = (args...) -> eps(eltype(m.contribution)))
resize!(m.contribution, length(newcontribution))
copyto!(m.contribution, newcontribution)
# no need to select anything
m.contribution[] = select(m.contribution[], 1 => outputs; newfun = (args...) -> eps(eltype(m.contribution[])))
NaiveNASlib.Δsize!(wrapped(m), inputs, outputs; kwargs...)

Expand Down Expand Up @@ -100,7 +107,8 @@ function neuronutility(lm::LazyMutable)
neuronutility(m::ActivationContribution) = m.contribution[]
# An empty contribution vector is the "no value" sentinel. Previous versions used `missing`
# as the sentinel instead, so `missing` is returned in that case to maintain the API.
function neuronutility(m::ActivationContribution)
    isempty(m.contribution) && return missing
    return m.contribution
end
# Fallback: dispatch on the layer type of `l` to select the utility computation.
function neuronutility(l)
    return neuronutility(layertype(l), l)
end

# Default: mean of abs of weights + bias. Not a very good metric, but should be better than random
Expand Down Expand Up @@ -134,7 +142,7 @@ neuronutility_safe(::MutationSizeTrait, v) = clean_values(cpu(neuronutility(v)))
# Compute the utility of `m`, pass it through `cpu`, and sanitize the result with `clean_values`.
function neuronutility_safe(m::AbstractMutableComp)
    utility = cpu(neuronutility(m))
    return clean_values(utility)
end

clean_values(::Missing) = 1
clean_values(a::AbstractArray) = replace(a, NaN => -100, Inf => -100, -Inf => -100)
# Sanitize an array of utility values.
# An empty array yields the scalar `1`; otherwise a copy of `a` is returned in which
# `NaN`, `Inf` and `-Inf` entries are replaced by `-100`.
function clean_values(a::AbstractArray)
    isempty(a) && return 1
    return replace(a, NaN => -100, Inf => -100, -Inf => -100)
end

neuronutilitytaylor(currval, act, grad)
Expand Down
@@ -0,0 +1,28 @@
using PrecompileTools

@setup_workload begin
iv1 = denseinputvertex("iv1", 1)
v1 = fluxvertex("v1", Dense(nout(iv1) => 1), iv1)
v2 = concat("v2", v1, v1; layerfun=ActivationContribution)
v3 = concat("v3", v2,v1,iv1)
v4 = "v4" >> v3 + v3
v5 = "v5" >> v4 + v4 + v4
v6 = fluxvertex("v6", Dense(nout(v5) => 1), v5; layerfun = ActivationContribution ∘ LazyMutable)

g1 = CompGraph(iv1, v6)
x1 = ones(Float32, 1, 1)

@compile_workload begin
iv1 = denseinputvertex("iv1", 1)
fluxvertex("v1", Dense(nout(iv1) => 1), iv1)

Flux.@code_adjoint g1(x1)
#Optimisers.setup(Optimisers.Descent(0.1f0), g1)
#Flux.gradient((g,x) -> sum(g(x)), g1, x1)

Δnout!(v3 => relaxed(2))
67 changes: 31 additions & 36 deletions test/chainrules.jl
Expand Up @@ -11,19 +11,13 @@

import Optimisers
function with_explicit_grads(f)
NaiveNASlib.enable_explicit_gradients[] = true
NaiveNASlib.enable_explicit_gradients[] = false

teststructs(g::CompGraph, res, exp; seen=Base.IdSet()) = foreach(enumerate(outputs(g))) do (i, vo)
teststructs(g::CompGraph{<:Any, <:Tuple}, res, exp; seen=Base.IdSet()) = foreach(enumerate(outputs(g))) do (i, vo)
teststructs(vo, seen, res.outputs[i] ,exp)

teststructs(g::CompGraph{<:Any, <:AbstractVertex}, res, exp; seen=Base.IdSet()) = teststructs(g.outputs, seen, res.outputs ,exp)

function teststructs(v::AbstractVertex, seen, res, exp)
v in seen && return
push!(seen, v)
Expand All @@ -33,11 +27,13 @@
function _teststructs(::NaiveNASflux.InputShapeVertex, seen, res, exp, name) end

function _teststructs(v::AbstractVertex, seen, res::RT, exp, name) where RT
@testset "Check structure for $(name) of type $(typeof(v))" begin
@test hasfield(RT, :base)
if hasfield(RT, :base)
_teststructs(base(v), seen, res.base, exp, name)
if layertype(v) isa NaiveNASflux.FluxParLayer
@testset "Check structure for $(name) of type $(typeof(v))" begin
@test hasfield(RT, :base)
if hasfield(RT, :base)
_teststructs(base(v), seen, res.base, exp, name)
function _teststructs(v::CompVertex, seen, res::RT, exp, name) where RT
Expand All @@ -46,6 +42,7 @@
_teststructs(v.computation, res.computation, exp, name)
foreach(enumerate(inputs(v))) do (i, vi)
isnothing(res.inputs) && return
teststructs(vi, seen, res.inputs[i], exp)
Expand Down Expand Up @@ -111,28 +108,26 @@
@test exp[p] == res[p]

with_explicit_grads() do
@test gradient(sum ∘ graph, indata) == gradient(sum ∘ chain, indata)

expex = Flux.gradient(c -> sum(c(indata)), chain)
resex = Flux.gradient(g -> sum(g(indata)), graph)
teststructs(graph, resex..., expex[1].layers)

@testset "Optimisers" begin
graphstate = Optimisers.setup(Optimisers.Adam(), graph)
chainstate = Optimisers.setup(Optimisers.Adam(), chain)
@testset "Setup state" begin
teststructs(graph, graphstate, chainstate.layers)
# TODO: Why deepcopy needed? fmap(copy, graph) does not seem to work?
graphstate, newgraph = Optimisers.update(graphstate, deepcopy(graph), resex...)
chainstate, newchain = Optimisers.update(chainstate, chain, expex...)
@testset "New state" begin
teststructs(newgraph, graphstate, chainstate.layers)
@testset "New model" begin
teststructs(newgraph, Optimisers.trainable(newgraph), Optimisers.trainable(newchain).layers)
@test gradient(sum ∘ graph, indata) == gradient(sum ∘ chain, indata)

expex = Flux.gradient(c -> sum(c(indata)), chain)
resex = Flux.gradient(g -> sum(g(indata)), graph)
teststructs(graph, resex..., expex[1].layers)

@testset "Optimisers" begin
graphstate = Optimisers.setup(Optimisers.Adam(), graph)
chainstate = Optimisers.setup(Optimisers.Adam(), chain)
@testset "Setup state" begin
teststructs(graph, graphstate, chainstate.layers)
# TODO: Why deepcopy needed? fmap(copy, graph) does not seem to work?
graphstate, newgraph = Optimisers.update(graphstate, deepcopy(graph), resex...)
chainstate, newchain = Optimisers.update(chainstate, chain, expex...)
@testset "New state" begin
teststructs(newgraph, graphstate, chainstate.layers)
@testset "New model" begin
teststructs(newgraph, Optimisers.trainable(newgraph), Optimisers.trainable(newchain).layers)

Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,18 @@ layer2 = densevertex(layer1, 1, sigmoid)
original = CompGraph(invertex, layer2)

## Training params, nothing to see here
opt = Adam(0.1)
loss(g) = (x, y) -> mse(g(x), y)
# Mean squared error between the prediction of model `f` on `x` and the target `y`.
function loss(f, x, y)
    prediction = f(x)
    return mse(prediction, y)
end

## Training data: xor truth table: y = xor(x) just so we don't need to download a dataset.
x = Float32[0 0 1 1;
0 1 0 1]
y = Float32[0 1 1 0]

trainiter = Iterators.repeated((x,y), niters)

## Train the model
train!(loss(original), params(original), Iterators.repeated((x,y), niters), opt)
@test loss(original)(x, y) < 0.001
train!(loss, original, trainiter, Flux.setup(Adam(0.1), original))
@test loss(original, x, y) < 0.001

# With that out of the way, lets try three different ways to prune the hidden layer (vertex nr 2 in the graph).
# To make examples easier to compare, lets decide up front that we want to remove half of the hidden layer neurons
Expand All @@ -61,13 +62,13 @@ pruned_random = deepcopy(original)
Δnout!(v -> rand(nout(v)), pruned_random[2] => -nprune)

# Free lunch anyone?
@test loss(pruned_most)(x, y) >
loss(pruned_random)(x, y) >
loss(pruned_least)(x, y) >=
loss(original)(x, y)
@test loss(pruned_most, x, y) >
loss(pruned_random, x, y) >
loss(pruned_least, x, y) >=
loss(original, x, y)

# The metric calculated by [`ActivationContribution`](@ref) is actually quite good in this case.
@test loss(pruned_least)(x, y) ≈ loss(original)(x, y) atol = 1e-5
@test loss(pruned_least, x, y) ≈ loss(original, x, y) atol = 1e-5
end #src

