From 2d172125e72ca932133266c3b23d0865bec539df Mon Sep 17 00:00:00 2001 From: st-- Date: Sat, 18 Dec 2021 23:43:30 +0200 Subject: [PATCH] Zygote AD failure workarounds & test cleanup (#414) Zygote AD failures: * revert #409 (test_utils workaround for broken Zygote - now working again) * disable broken Zygote AD test for ChainTransform Improved tests: * finer-grained testsets * add missing test cases to test_AD * replace test_FiniteDiff with test_AD(..., :FiniteDiff, ...) * remove code duplication --- test/test_utils.jl | 160 ++++++++++--------------------- test/transform/chaintransform.jl | 4 +- 2 files changed, 53 insertions(+), 111 deletions(-) diff --git a/test/test_utils.jl b/test/test_utils.jl index 141fcb96e..1871a99ca 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -69,6 +69,10 @@ function gradient(f, ::Val{:FiniteDiff}, args) return first(FiniteDifferences.grad(FDM, f, args)) end +function compare_gradient(f, ::Val{:FiniteDiff}, args) + @test_nowarn gradient(f, :FiniteDiff, args) +end + function compare_gradient(f, AD::Symbol, args) grad_AD = gradient(f, AD, args) grad_FD = gradient(f, :FiniteDiff, args) @@ -88,7 +92,7 @@ testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B)) function test_ADs( kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3] ) - test_fd = test_FiniteDiff(kernelfunction, args, dims) + test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims) if !test_fd.anynonpass for AD in ADs test_AD(AD, kernelfunction, args, dims) @@ -100,7 +104,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) @inferred f(args...) @inferred Zygote._pullback(ctx, f, args...) out, pb = Zygote._pullback(ctx, f, args...) - @test_throws ErrorException @inferred pb(out) + @inferred pb(out) end function test_ADs( @@ -114,70 +118,6 @@ function test_ADs( end end -function test_FiniteDiff(kernelfunction, args=nothing, dims=[3, 3]) - # Init arguments : - k = if args === nothing - kernelfunction() - else - kernelfunction(args) - end - rng = MersenneTwister(42) - @testset "FiniteDifferences" begin - if k isa SimpleKernel - for d in log.([eps(), rand(rng)]) - @test_nowarn gradient(:FiniteDiff, [d]) do x - kappa(k, exp(first(x))) - end - end - end - ## Testing Kernel Functions - x = rand(rng, dims[1]) - y = rand(rng, dims[1]) - @test_nowarn gradient(:FiniteDiff, x) do x - k(x, y) - end - if !(args === nothing) - @test_nowarn gradient(:FiniteDiff, args) do p - kernelfunction(p)(x, y) - end - end - ## Testing Kernel Matrices - A = rand(rng, dims...) - B = rand(rng, dims...) - for dim in 1:2 - @test_nowarn gradient(:FiniteDiff, A) do a - testfunction(k, a, dim) - end - @test_nowarn gradient(:FiniteDiff, A) do a - testfunction(k, a, B, dim) - end - @test_nowarn gradient(:FiniteDiff, B) do b - testfunction(k, A, b, dim) - end - if !(args === nothing) - @test_nowarn gradient(:FiniteDiff, args) do p - testfunction(kernelfunction(p), A, B, dim) - end - end - - @test_nowarn gradient(:FiniteDiff, A) do a - testdiagfunction(k, a, dim) - end - @test_nowarn gradient(:FiniteDiff, A) do a - testdiagfunction(k, a, B, dim) - end - @test_nowarn gradient(:FiniteDiff, B) do b - testdiagfunction(k, A, b, dim) - end - if args !== nothing - @test_nowarn gradient(:FiniteDiff, args) do p - testdiagfunction(kernelfunction(p), A, B, dim) - end - end - end - end -end - function test_FiniteDiff(k::MOKernel, dims=(in=3, out=2, obs=3)) rng = MersenneTwister(42) @testset "FiniteDifferences" begin @@ -224,68 +164,68 @@ end function test_AD(AD::Symbol, kernelfunction, args=nothing, dims=[3, 3]) @testset "$(AD)" begin - # Test kappa function k = if args === nothing kernelfunction() else kernelfunction(args) end rng = MersenneTwister(42) + if k isa SimpleKernel - for d in log.([eps(), rand(rng)]) - compare_gradient(AD, [d]) do x - kappa(k, exp(x[1])) + @testset "kappa function" begin + for d in log.([eps(), rand(rng)]) + compare_gradient(AD, [d]) do x + kappa(k, exp(x[1])) + end end end end - # Testing kernel evaluations - x = rand(rng, dims[1]) - y = rand(rng, dims[1]) - compare_gradient(AD, x) do x - k(x, y) - end - compare_gradient(AD, y) do y - k(x, y) - end - if !(args === nothing) - compare_gradient(AD, args) do p - kernelfunction(p)(x, y) - end - end - # Testing kernel matrices - A = rand(rng, dims...) - B = rand(rng, dims...) - for dim in 1:2 - compare_gradient(AD, A) do a - testfunction(k, a, dim) - end - compare_gradient(AD, A) do a - testfunction(k, a, B, dim) + + @testset "kernel evaluations" begin + x = rand(rng, dims[1]) + y = rand(rng, dims[1]) + compare_gradient(AD, x) do x + k(x, y) end - compare_gradient(AD, B) do b - testfunction(k, A, b, dim) + compare_gradient(AD, y) do y + k(x, y) end if !(args === nothing) - compare_gradient(AD, args) do p - testfunction(kernelfunction(p), A, dim) + @testset "hyperparameters" begin + compare_gradient(AD, args) do p + kernelfunction(p)(x, y) + end end end + end - compare_gradient(AD, A) do a - testdiagfunction(k, a, dim) - end - compare_gradient(AD, A) do a - testdiagfunction(k, a, B, dim) - end - compare_gradient(AD, B) do b - testdiagfunction(k, A, b, dim) - end - if args !== nothing - compare_gradient(AD, args) do p - testdiagfunction(kernelfunction(p), A, dim) + @testset "kernel matrices" begin + A = rand(rng, dims...) + B = rand(rng, dims...) + @testset "$(_testfn)" for _testfn in (testfunction, testdiagfunction) + for dim in 1:2 + compare_gradient(AD, A) do a + _testfn(k, a, dim) + end + compare_gradient(AD, A) do a + _testfn(k, a, B, dim) + end + compare_gradient(AD, B) do b + _testfn(k, A, b, dim) + end + if !(args === nothing) + @testset "hyperparameters" begin + compare_gradient(AD, args) do p + _testfn(kernelfunction(p), A, dim) + end + compare_gradient(AD, args) do p + _testfn(kernelfunction(p), A, B, dim) + end + end + end end end - end + end # kernel matrices end end diff --git a/test/transform/chaintransform.jl b/test/transform/chaintransform.jl index 6e3d8a44f..f8b19cffe 100644 --- a/test/transform/chaintransform.jl +++ b/test/transform/chaintransform.jl @@ -24,6 +24,8 @@ @test repr(tp ∘ tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)" test_ADs( x -> SEKernel() ∘ (ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))), - randn(rng, 4), + randn(rng, 4); + ADs=[:ForwardDiff, :ReverseDiff], # explicitly pass ADs to exclude :Zygote ) + @test_broken "test_AD of chain transform is currently broken in Zygote, see GitHub issue #263" end