Merge pull request #90 from DrChainsaw/funcmemo
Change examples and test to use explicit gradients
DrChainsaw authored Aug 7, 2023
2 parents 7d1150e + 36d24bb commit f502e5f
Showing 10 changed files with 133 additions and 96 deletions.
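For context, the "explicit gradients" in the commit title refer to the newer Flux training style where gradients are taken with respect to the model itself and paired with optimiser state from `Flux.setup`, replacing the older implicit `Flux.params` style. A minimal sketch of the two styles, assuming a throwaway `Dense` layer and random data that are not part of this repository:

```julia
using Flux

model = Dense(2 => 1, sigmoid)
x, y = rand(Float32, 2, 4), rand(Float32, 1, 4)

# Older implicit style: gradients with respect to a Params collection.
ps = Flux.params(model)
gs = Flux.gradient(() -> Flux.Losses.mse(model(x), y), ps)

# Newer explicit style: gradients with respect to the model, optimiser state from setup.
opt_state = Flux.setup(Flux.Adam(0.1), model)
grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)
Flux.update!(opt_state, model, grads[1])
```

The diffs below move the package's examples and tests to the explicit style.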
6 changes: 5 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "NaiveNASflux"
uuid = "85610aed-7d32-5e57-bb50-4c2e1c9e7997"
version = "2.0.9"
version = "2.0.10"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -9,6 +9,8 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NaiveNASlib = "bd45eb3e-47ce-54bd-9eaf-e86c5f900853"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
@@ -20,6 +22,8 @@ Flux = "0.13, 0.14"
Functors = "0.2, 0.3, 0.4"
JuMP = "0.19, 0.20, 0.21, 0.22, 0.23, 1"
NaiveNASlib = "2.0.6"
Optimisers = "0.2"
PrecompileTools = "1"
Reexport = "0.2.0, 1.0"
Setfield = "0.3.4, 0.5, 0.6, 0.7, 0.8, 1.0"
julia = "1"
5 changes: 4 additions & 1 deletion src/NaiveNASflux.jl
@@ -6,6 +6,7 @@ using NaiveNASlib.Extend, NaiveNASlib.Advanced
import Flux
using Flux: Dense, Conv, ConvTranspose, CrossCor, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
MaxPool, MeanPool, Dropout, AlphaDropout, GlobalMaxPool, GlobalMeanPool, cpu
import Optimisers
import Functors
using Functors: @functor
using Statistics
@@ -39,4 +40,6 @@ include("neuronutility.jl")
# Stuff to integrate with Zygote
include("chainrules.jl")

end # module
include("precompile.jl")

end # module
8 changes: 4 additions & 4 deletions src/chainrules.jl
@@ -11,8 +11,8 @@ ChainRulesCore.@non_differentiable mutate(args...)
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, m::MutableLayer, args...)
res, back = rrule_via_ad(config, m.layer, args...)
function MutableLayer_back(Δ)
δlayer, δargs = back(Δ)
Tangent{MutableLayer}(layer=δlayer), δargs
δs = back(Δ)
Tangent{MutableLayer}(layer=δs[1]), δs[2:end]...
end
return res, MutableLayer_back
end
@@ -23,8 +23,8 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, m::LazyMutab
forcemutation(m)
res, back = rrule_via_ad(config, m.mutable, args...)
function LazyMutable_back(Δ)
δmutable, δargs = back(Δ)
Tangent{LazyMutable}(mutable=δmutable), δargs
δs = back(Δ)
Tangent{LazyMutable}(mutable=δs[1]), δs[2:end]...
end
return res, LazyMutable_back
end
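The pullback change above matters when the wrapped layer is called with more than one argument: `back(Δ)` returns one cotangent per input, so splatting `δs[2:end]...` forwards them all, whereas the earlier two-value destructuring only handled a single argument correctly. A self-contained sketch of the same pattern for a hypothetical `Wrapper` type (not part of NaiveNASflux), usable with any AD backend that supplies a reverse-mode `RuleConfig`, such as Zygote:

```julia
using ChainRulesCore

struct Wrapper{F}
    f::F
end
(w::Wrapper)(x...) = w.f(x...)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, w::T, x...) where {T<:Wrapper}
    # Differentiate the wrapped callable via the AD backend itself.
    y, back = rrule_via_ad(config, w.f, x...)
    function Wrapper_back(Δ)
        δs = back(Δ)  # (δf, δx1, δx2, ...) -- one cotangent per input
        # Re-wrap the callable's cotangent in a structural tangent for Wrapper
        # and pass the argument cotangents through, however many there are.
        return Tangent{T}(f = δs[1]), δs[2:end]...
    end
    return y, Wrapper_back
end
```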
1 change: 1 addition & 0 deletions src/mutable.jl
@@ -11,6 +11,7 @@ Base.Broadcast.broadcastable(m::AbstractMutableComp) = Ref(m)

# Not possible in julia <= 1.1. See #14919
# (m::AbstractMutableComp)(x...) = layer(m)(x...)
layer(m::AbstractMutableComp) = layer(wrapped(m))
layertype(m::AbstractMutableComp) = layertype(layer(m))

NaiveNASlib.nin(m::AbstractMutableComp) = nin(wrapped(m))
80 changes: 44 additions & 36 deletions src/neuronutility.jl
@@ -1,9 +1,11 @@
"""
ActivationContribution{L,M} <: AbstractMutableComp
ActivationContribution{L,C,M} <: AbstractMutableComp
ActivationContribution(l)
ActivationContribution(l, method)
Calculate neuron utility based on activations and gradients using `method`.
Calculate neuron utility based on activations and gradients using `method`.
Designed to be used as `layerfun` argument to [`fluxvertex`](@ref).
Can be a performance bottleneck in cases with large activations. Use [`NeuronUtilityEvery`](@ref) to mitigate.
@@ -13,58 +15,63 @@ Short summary is that the first order taylor approximation of the optimization p
boils down to: "the ones which minimize `abs(gradient * activation)`" (assuming parameter independence).
"""
struct ActivationContribution{L,M} <: AbstractMutableComp
struct ActivationContribution{L,C,M} <: AbstractMutableComp
layer::L
contribution::Base.RefValue{Any} # Type of activation not known yet :( Also leave some room for experimenting with things like storing the metric on the GPU
contribution::C
method::M
end
# We use eps(Float32) here because we don't want parameters from new layers to have:
# 1) higher utility than parameters from existing layers
# 2) zero utility since that will often make the optimizer remove them completely
# eps(Float32) is typically smaller than the optimizer tolerance, but NaiveNASlib tries to rescale
ActivationContribution(l::AbstractMutableComp, method = Ewma(0.05f0)) = ActivationContribution(l, Ref{Any}(fill(eps(Float32), nout(l))), method)
ActivationContribution(l, method = Ewma(0.05f0)) = ActivationContribution(l, Ref{Any}(missing), method)

layer(m::ActivationContribution) = layer(m.layer)
layertype(m::ActivationContribution) = layertype(m.layer)
wrapped(m::ActivationContribution) = m.layer
ActivationContribution(l::AbstractMutableComp, method = Ewma(0.05f0)) = ActivationContribution(l, fill(eps(Float32), nout(l)), method)
ActivationContribution(l, method = Ewma(0.05f0)) = ActivationContribution(l, Float32[], method)

Flux.trainable(m::ActivationContribution) = (;layer = Flux.trainable(wrapped(m)))
@functor ActivationContribution

function Functors.functor(t::Type{<:ActivationContribution}, m::ActivationContribution)
_functor(t, m.layer, m.contribution[], m.method)
end

function Functors.functor(t::Type{<:ActivationContribution}, m)
_functor(t, m.layer, m.contribution, m.method)
end

function _functor(::Type{<:ActivationContribution}, layer, contribution, method)
return (;layer, contribution, method), function(y)
ActivationContribution(y.layer, Ref{Any}(y.contribution), y.method)
end
end

function(m::ActivationContribution)(x...)
act = wrapped(m)(x...)
wrapped(m::ActivationContribution) = m.layer

return Flux.Zygote.hook(act) do grad
grad === nothing && return grad
m.contribution[] = m.method(m.contribution[], act, grad)
return grad
# We do train contribution in some sense, but we don't want Flux to do it
# We could create a "fake" gradient in the rrule and let the optimizer rule update it for us
# (rather than using our own Ewma), but it is probably not desirable to mix the model parameter update
# strategy with the activation contribution strategy.
Flux.trainable(m::ActivationContribution) = (;layer = Flux.trainable(m.layer))

# Just passthrough when not taking gradients.
(m::ActivationContribution)(x...) = wrapped(m)(x...)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, m::T, x...) where T <:ActivationContribution
act, back = rrule_via_ad(config, wrapped(m), x...)

function ActivationContribution_back(Δ)
if length(m.contribution) === 0
newcontribution = m.method(missing, act, Δ)
resize!(m.contribution, length(newcontribution))
copyto!(m.contribution, newcontribution)
else
copyto!(m.contribution, m.method(m.contribution, act, Δ))
end

δs = back(Δ)
Tangent{T}(layer=δs[1]), δs[2:end]...
end
return act, ActivationContribution_back
end

actdim(nd::Integer) = nd - 1

function NaiveNASlib.Δsize!(m::ActivationContribution, inputs::AbstractVector, outputs::AbstractVector; kwargs...)
if m.contribution[] !== missing
if length(m.contribution) !== 0
# This tends to happen when we are measuring contribution for a concatenation and we have added an extra input edge
# TODO: Try to find another fix, perhaps we need to ensure that nout(v) if v wraps an ActivationContribution always return
# the length of m.contribution
outputs[outputs .> length(m.contribution[])] .= -1
outputs[outputs .> length(m.contribution)] .= -1
newcontribution = select(m.contribution, 1 => outputs; newfun = (args...) -> eps(eltype(m.contribution)))
resize!(m.contribution, length(newcontribution))
copyto!(m.contribution, newcontribution)
#else
# no need to select anything
end
m.contribution[] = select(m.contribution[], 1 => outputs; newfun = (args...) -> eps(eltype(m.contribution[])))
NaiveNASlib.Δsize!(wrapped(m), inputs, outputs; kwargs...)
end

@@ -100,7 +107,8 @@ function neuronutility(lm::LazyMutable)
forcemutation(lm)
neuronutility(wrapped(lm))
end
neuronutility(m::ActivationContribution) = m.contribution[]
# Return missing to maintain API since previous versions used missing as sentinel value instead of empty vector
neuronutility(m::ActivationContribution) = isempty(m.contribution) ? missing : m.contribution
neuronutility(l) = neuronutility(layertype(l), l)

# Default: mean of abs of weights + bias. Not a very good metric, but should be better than random
@@ -134,7 +142,7 @@ neuronutility_safe(::MutationSizeTrait, v) = clean_values(cpu(neuronutility(v)))
neuronutility_safe(m::AbstractMutableComp) = clean_values(cpu(neuronutility(m)))

clean_values(::Missing) = 1
clean_values(a::AbstractArray) = replace(a, NaN => -100, Inf => -100, -Inf => -100)
clean_values(a::AbstractArray) = length(a) === 0 ? 1 : replace(a, NaN => -100, Inf => -100, -Inf => -100)

"""
neuronutilitytaylor(currval, act, grad)
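The docstring earlier in this file summarises the utility metric as the first-order Taylor term `abs(gradient * activation)` per output neuron, smoothed across batches by `method` (an `Ewma` by default). As a rough, assumption-laden illustration only — the exact reduction and smoothing in NaiveNASflux may differ — an EWMA-style update for a `features × batch` activation could look like:

```julia
using Statistics

# Hypothetical per-neuron contribution update: |activation * gradient| averaged
# over the batch dimension, blended with the previous estimate (missing at first).
function ewma_contribution(prev, act::AbstractMatrix, grad::AbstractMatrix; α = 0.05f0)
    current = vec(mean(abs.(act .* grad); dims = 2))  # one value per output neuron
    prev === missing && return current
    return (1 - α) .* prev .+ α .* current
end

c1 = ewma_contribution(missing, rand(Float32, 3, 8), rand(Float32, 3, 8))
c2 = ewma_contribution(c1, rand(Float32, 3, 8), rand(Float32, 3, 8))
```

In the new code this update happens inside the `rrule` pullback, which is why the contribution vector is only populated when gradients are actually taken.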
28 changes: 28 additions & 0 deletions src/precompile.jl
@@ -0,0 +1,28 @@
using PrecompileTools

let
@setup_workload begin
iv1 = denseinputvertex("iv1", 1)
v1 = fluxvertex("v1", Dense(nout(iv1) => 1), iv1)
v2 = concat("v2", v1, v1; layerfun=ActivationContribution)
v3 = concat("v3", v2,v1,iv1)
v4 = "v4" >> v3 + v3
v5 = "v5" >> v4 + v4 + v4
v6 = fluxvertex("v6", Dense(nout(v5) => 1), v5; layerfun = ActivationContribution LazyMutable)

g1 = CompGraph(iv1, v6)
x1 = ones(Float32, 1, 1)

@compile_workload begin
iv1 = denseinputvertex("iv1", 1)
fluxvertex("v1", Dense(nout(iv1) => 1), iv1)

g1(x1)
Flux.@code_adjoint g1(x1)
#Optimisers.setup(Optimisers.Descent(0.1f0), g1)
#Flux.gradient((g,x) -> sum(g(x)), g1, x1)

Δnout!(v3 => relaxed(2))
end
end
end
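For readers unfamiliar with PrecompileTools: code inside `@setup_workload` runs during precompilation, but only calls made inside the nested `@compile_workload` block are recorded and cached, so the graph construction and evaluation above are what actually end up precompiled. A minimal sketch of the pattern, with a module and function invented purely for illustration:

```julia
module MyPkg

using PrecompileTools

my_kernel(x) = sum(abs2, x)

@setup_workload begin
    # Executed at precompile time; compilation triggered here is not cached.
    data = rand(Float32, 4)
    @compile_workload begin
        # Calls here are compiled and stored in the package image.
        my_kernel(data)
    end
end

end # module
```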
67 changes: 31 additions & 36 deletions test/chainrules.jl
@@ -11,19 +11,13 @@
end

import Optimisers
function with_explicit_grads(f)
try
NaiveNASlib.enable_explicit_gradients[] = true
f()
finally
NaiveNASlib.enable_explicit_gradients[] = false
end
end

teststructs(g::CompGraph, res, exp; seen=Base.IdSet()) = foreach(enumerate(outputs(g))) do (i, vo)
teststructs(g::CompGraph{<:Any, <:Tuple}, res, exp; seen=Base.IdSet()) = foreach(enumerate(outputs(g))) do (i, vo)
teststructs(vo, seen, res.outputs[i] ,exp)
end

teststructs(g::CompGraph{<:Any, <:AbstractVertex}, res, exp; seen=Base.IdSet()) = teststructs(g.outputs, seen, res.outputs ,exp)

function teststructs(v::AbstractVertex, seen, res, exp)
v in seen && return
push!(seen, v)
@@ -33,11 +27,13 @@
function _teststructs(::NaiveNASflux.InputShapeVertex, seen, res, exp, name) end

function _teststructs(v::AbstractVertex, seen, res::RT, exp, name) where RT
@testset "Check structure for $(name) of type $(typeof(v))" begin
@test hasfield(RT, :base)
end
if hasfield(RT, :base)
_teststructs(base(v), seen, res.base, exp, name)
if layertype(v) isa NaiveNASflux.FluxParLayer
@testset "Check structure for $(name) of type $(typeof(v))" begin
@test hasfield(RT, :base)
end
if hasfield(RT, :base)
_teststructs(base(v), seen, res.base, exp, name)
end
end
end
function _teststructs(v::CompVertex, seen, res::RT, exp, name) where RT
@@ -46,6 +42,7 @@
_teststructs(v.computation, res.computation, exp, name)
end
foreach(enumerate(inputs(v))) do (i, vi)
isnothing(res.inputs) && return
teststructs(vi, seen, res.inputs[i], exp)
end
end
@@ -111,28 +108,26 @@
@test exp[p] == res[p]
end

with_explicit_grads() do
@test gradient(sum ∘ graph, indata) == gradient(sum ∘ chain, indata)

expex = Flux.gradient(c -> sum(c(indata)), chain)
resex = Flux.gradient(g -> sum(g(indata)), graph)
teststructs(graph, resex..., expex[1].layers)

@testset "Optimisers" begin
graphstate = Optimisers.setup(Optimisers.Adam(), graph)
chainstate = Optimisers.setup(Optimisers.Adam(), chain)
@testset "Setup state" begin
teststructs(graph, graphstate, chainstate.layers)
end
# TODO: Why deepcopy needed? fmap(copy, graph) does not seem to work?
graphstate, newgraph = Optimisers.update(graphstate, deepcopy(graph), resex...)
chainstate, newchain = Optimisers.update(chainstate, chain, expex...)
@testset "New state" begin
teststructs(newgraph, graphstate, chainstate.layers)
end
@testset "New model" begin
teststructs(newgraph, Optimisers.trainable(newgraph), Optimisers.trainable(newchain).layers)
end
@test gradient(sum ∘ graph, indata) == gradient(sum ∘ chain, indata)

expex = Flux.gradient(c -> sum(c(indata)), chain)
resex = Flux.gradient(g -> sum(g(indata)), graph)
teststructs(graph, resex..., expex[1].layers)

@testset "Optimisers" begin
graphstate = Optimisers.setup(Optimisers.Adam(), graph)
chainstate = Optimisers.setup(Optimisers.Adam(), chain)
@testset "Setup state" begin
teststructs(graph, graphstate, chainstate.layers)
end
# TODO: Why deepcopy needed? fmap(copy, graph) does not seem to work?
graphstate, newgraph = Optimisers.update(graphstate, deepcopy(graph), resex...)
chainstate, newchain = Optimisers.update(chainstate, chain, expex...)
@testset "New state" begin
teststructs(newgraph, graphstate, chainstate.layers)
end
@testset "New model" begin
teststructs(newgraph, Optimisers.trainable(newgraph), Optimisers.trainable(newchain).layers)
end
end

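The structural checks above rely on explicit gradients and `Optimisers.setup` state both mirroring the model's own structure, which is what `teststructs` walks through. A small sketch of that correspondence on a plain `Dense` layer (layer and data chosen arbitrarily):

```julia
using Flux, Optimisers

m = Dense(2 => 3)
x = rand(Float32, 2, 5)

grads = Flux.gradient(model -> sum(model(x)), m)
# grads[1] is a NamedTuple mirroring Dense: (weight = ..., bias = ..., σ = nothing)

state = Optimisers.setup(Optimisers.Adam(), m)
# state mirrors the same structure, with a Leaf of optimiser state per parameter array.

state, m = Optimisers.update(state, m, grads[1])
```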
19 changes: 10 additions & 9 deletions test/examples/xorpruning.jl
@@ -26,17 +26,18 @@ layer2 = densevertex(layer1, 1, sigmoid)
original = CompGraph(invertex, layer2)

## Training params, nothing to see here
opt = Adam(0.1)
loss(g) = (x, y) -> mse(g(x), y)
loss(f, x, y) = mse(f(x), y)

## Training data: xor truth table: y = xor(x) just so we don't need to download a dataset.
x = Float32[0 0 1 1;
0 1 0 1]
y = Float32[0 1 1 0]

trainiter = Iterators.repeated((x,y), niters)

## Train the model
train!(loss(original), params(original), Iterators.repeated((x,y), niters), opt)
@test loss(original)(x, y) < 0.001
train!(loss, original, trainiter, Flux.setup(Adam(0.1), original))
@test loss(original, x, y) < 0.001

# With that out of the way, lets try three different ways to prune the hidden layer (vertex nr 2 in the graph).
# To make examples easier to compare, lets decide up front that we want to remove half of the hidden layer neurons
Expand All @@ -61,13 +62,13 @@ pruned_random = deepcopy(original)
Δnout!(v -> rand(nout(v)), pruned_random[2] => -nprune)

# Free lunch anyone?
@test loss(pruned_most)(x, y) >
loss(pruned_random)(x, y) >
loss(pruned_least)(x, y) >=
loss(original)(x, y)
@test loss(pruned_most, x, y) >
loss(pruned_random, x, y) >
loss(pruned_least, x, y) >=
loss(original, x, y)

# The metric calculated by [`ActivationContribution`](@ref) is actually quite good in this case.
@test loss(pruned_least)(x, y) ≈ loss(original)(x, y) atol = 1e-5
@test loss(pruned_least, x, y) ≈ loss(original, x, y) atol = 1e-5
end #src



2 comments on commit f502e5f

@DrChainsaw
Owner Author


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/89186

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.0.10 -m "<description of version>" f502e5f016d9d9a6cf76c21e5fc4202022dc35cc
git push origin v2.0.10
