Skip to content

Commit 3092918

Browse files
authored
Merge pull request #49 from DrChainsaw/flux0.10.4
Flux0.10.4
2 parents c2127bc + 86d7917 commit 3092918

File tree

6 files changed

+26
-9
lines changed

6 files changed

+26
-9
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "NaiveNASflux"
22
uuid = "85610aed-7d32-5e57-bb50-4c2e1c9e7997"
3-
version = "1.1.2"
3+
version = "1.2.0"
44

55
[deps]
66
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -14,7 +14,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1414
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1515

1616
[compat]
17-
Flux = "=0.10.3"
17+
Flux = "0.10.4"
1818
JuMP = "0.19, 0.20, 0.21"
1919
NaiveNASlib = "1"
2020
Reexport = "0.2.0"

src/NaiveNASflux.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
module NaiveNASflux
22

33
using Reexport
4-
@reexport using NaiveNASlib
5-
@reexport using Flux
4+
using NaiveNASlib
5+
using Flux
66
import Flux.Zygote: hook
77
using Statistics
88
using Setfield
@@ -37,4 +37,9 @@ include("weightinit.jl")
3737
include("functor.jl")
3838
include("zygote.jl")
3939

40+
# Reexporting before include("functor.jl") causes a warning about duplicate name (flatten) in NaiveNASlib and Flux when subtypes are called
41+
# https://discourse.julialang.org/t/avoid-error-message-for-function-name-conflict/37176/10
42+
@reexport using NaiveNASlib
43+
@reexport using Flux
44+
4045
end # module

src/types.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ layertype(l::MaxPool) = FluxNoParLayer()
5151
layertype(l::MeanPool) = FluxNoParLayer()
5252
layertype(l::Dropout) = FluxNoParLayer()
5353
layertype(l::AlphaDropout) = FluxNoParLayer()
54+
layertype(l::GlobalMaxPool) = FluxNoParLayer()
55+
layertype(l::GlobalMeanPool) = FluxNoParLayer()
5456

5557
# Compositions? Might not have any common methods...
5658
# MaxOut, Chain?

src/zygote.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ end
2828

2929
function output_loop!(memo, v)
3030
vs = nograd() do
31-
# flatten returns all input ancestors to v in topological order
31+
# NaiveNASlib.flatten returns all input ancestors to v in topological order
3232
# We also provide all vertices for which we have the output already in memo
3333
# so we don't do unnecessary calculations.
34-
flatten(v, collect(AbstractVertex, keys(memo)))[length(memo)+1:end]
34+
NaiveNASlib.flatten(v, collect(AbstractVertex, keys(memo)))[length(memo)+1:end]
3535
end
3636

3737
for vn in vs

test/pruning.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@
8181
@test size(neuron_value(l)) == (2,)
8282
end
8383

84+
@testset "Neuron value GlobalMeanPool act contrib" begin
85+
l = ml(GlobalMeanPool(), ActivationContribution, insize=2)
86+
@test ismissing(neuron_value(l))
87+
tr(l, ones(Float32, 4,4,2,5))
88+
89+
@test ismissing(minΔninfactor(l))
90+
@test ismissing(minΔnoutfactor(l))
91+
@test size(neuron_value(l)) == (2,)
92+
end
93+
8494
@testset "Elem add ActivationContribution" begin
8595
ac(l) = ActivationContribution(l)
8696
v = ac >> ml(Dense(2,3)) + ml(Dense(4,3))

test/vertex.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ end
4747

4848
W2exp, b2exp = weights(dl2)[:, inds], bias(dl2)
4949
W1exp, b1exp = weights(dl1)[inds, :], bias(dl1)[inds]
50-
apply_mutation.(flatten(dense2))
50+
apply_mutation.(NaiveNASlib.flatten(dense2))
5151

5252
@test size(CompGraph([inpt], [dense2])(collect(Float32, 1:nout(inpt)))) == (3,)
5353

@@ -63,7 +63,7 @@ end
6363
@test nin(bv) == [nout(cv)] == [4]
6464

6565
Δnin(bv, -1)
66-
apply_mutation.(flatten(bv))
66+
apply_mutation.(NaiveNASlib.flatten(bv))
6767

6868
@test nin(bv) == [nout(cv)] == [3]
6969
end
@@ -76,7 +76,7 @@ end
7676
@test nin(bv) == [nout(cv)] == [4]
7777

7878
Δnin(bv, -1)
79-
apply_mutation.(flatten(bv))
79+
apply_mutation.(NaiveNASlib.flatten(bv))
8080

8181
@test nin(bv) == [nout(cv)] == [3]
8282
end

0 commit comments

Comments
 (0)