Introduce Wrapper (instead of specializing kernel evaluation on DiffPt) (#10)

* enableDiffWrap instead of general

* jldoctest

* trim jldoctest

* warning docs, shorten name

* identify broken tests

* push broken state (upstream issue)

JuliaGaussianProcesses/KernelFunctions.jl#517
FelixBenning authored Jun 5, 2023
1 parent 4fb8d88 commit 917c4da
Showing 5 changed files with 68 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/DifferentiableKernelFunctions.jl
@@ -4,7 +4,7 @@ using Reexport

@reexport using KernelFunctions

-export partial
+export partial, EnableDiff

include("multiOutput.jl")
include("partial.jl")
62 changes: 41 additions & 21 deletions src/diffKernel.jl
@@ -3,31 +3,51 @@ import LinearAlgebra as LA
using KernelFunctions: SimpleKernel, Kernel

"""
-    _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
+    diffKernelCall(k::T, (x,px)::DiffPt, (y,py)::DiffPt) where {Dim, T<:Kernel}
-implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. But since
-generics are not allowed in the syntax above by the dispatch system, this
-redirection over `_evaluate` is necessary. It unboxes the partial
-instructions from DiffPt, applies them to k, and evaluates them at the
-positions of DiffPt.
+Specialization for DiffPt. Unboxes the partial instructions from DiffPt,
+applies them to k, and evaluates them at the positions of DiffPt.
"""
-function _evaluate(k::T, (x,px)::DiffPt, (y,py)::DiffPt) where {T<:Kernel}
+function diffKernelCall(k::T, (x,px)::DiffPt, (y,py)::DiffPt) where {T<:Kernel}
return apply_partial(k, px.indices, py.indices)(x, y)
end

-#=
-This is a hack to work around the fact that the `where {T<:Kernel}` clause is
-not allowed for the `(::T)(x,y)` syntax. If we were to only implement
-```julia
-(::Kernel)(::DiffPt,::DiffPt)
-```
-then julia would not know whether to use
-`(::SpecialKernel)(x,y)` or `(::Kernel)(x::DiffPt, y::DiffPt)`
+"""
+    EnableDiff
+
+A thin wrapper around kernels enabling the machinery which allows you to
+input (x, ∂ᵢ), (y, ∂ⱼ), where ∂ᵢ, ∂ⱼ are of `Partial` type (see [partial](@ref)),
+in order to calculate
+``
+k((x, ∂ᵢ), (y, ∂ⱼ)) = \\text{Cov}(\\partial_i Z(x), \\partial_j Z(y))
+``
+for ``Z`` with ``k(x,y) = \\text{Cov}(Z(x), Z(y))``.
+
+!!! warning
+    Only apply this wrapper at the very end. Kernel transformations should be
+    applied beforehand.
+
+!!! info
+    While this machinery could in principle be enabled for all `Kernel` by default,
+    the covariance of the derivatives of an isotropic kernel is no longer isotropic.
+    This forces the use of less specialized methods, so for now you have to opt in
+    with this wrapper.
+
+Example:
+```jldoctest
+julia> k = EnableDiff(SEKernel());
+
+julia> k((0, partial(1)), 0) # calculate Cov(∂₁Z(0), Z(0))
+0.0
+
+julia> k(0,0) # normal input still works
+1.0
+```
-=#
-for T in [SimpleKernel, Kernel] #subtypes(Kernel)
-    (k::T)(x::DiffPt, y::DiffPt) = _evaluate(k, x, y)
-    (k::T)(x::DiffPt, y) = _evaluate(k, x,(y, partial()))
-    (k::T)(x, y::DiffPt) = _evaluate(k, (x, partial()), y)
+"""
+struct EnableDiff{T<:Kernel} <: Kernel
+    kernel::T
end
+(k::EnableDiff)(x::DiffPt, y::DiffPt) = diffKernelCall(k.kernel, x, y)
+(k::EnableDiff)(x::DiffPt, y) = diffKernelCall(k.kernel, x, (y, partial()))
+(k::EnableDiff)(x, y::DiffPt) = diffKernelCall(k.kernel, (x, partial()), y)
+(k::EnableDiff)(x, y) = k.kernel(x, y) # fall-through case

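For orientation, a minimal usage sketch of the wrapper introduced above (the vector inputs and the mixed-derivative call are illustrative additions, not taken from this commit):

```julia
using DifferentiableKernelFunctions  # reexports KernelFunctions; exports partial, EnableDiff

k = EnableDiff(SEKernel())  # opt in to derivative evaluation; wrap last, after any kernel transformations

x, y = [0.0, 0.0], [1.0, 0.5]
k(x, y)                              # Cov(Z(x), Z(y)); plain calls fall through to SEKernel
k((x, partial(1)), y)                # Cov(∂₁Z(x), Z(y))
k((x, partial(1)), (y, partial(2)))  # Cov(∂₁Z(x), ∂₂Z(y))
```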
2 changes: 1 addition & 1 deletion src/partial.jl
@@ -105,6 +105,6 @@ i.e. 2*dim dimensional input
function apply_partial(
k, partials_x::Tuple{Vararg{T}}, partials_y::Tuple{Vararg{T}}
) where {T<:IndexType}
-local f(x, y) = apply_partial(t -> k(t, y), partials_x...)(x)
+f(x, y) = apply_partial(t -> k(t, y), partials_x...)(x)
return (x, y) -> apply_partial(t -> f(x, t), partials_y...)(y)
end
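The change above only drops a redundant `local`; the two-stage nesting (apply the `x`-partials with `y` fixed, then the `y`-partials with `x` fixed) is unchanged. As a rough standalone illustration of that idea, a sketch using ForwardDiff for a single partial index per argument (`mixed_partial` is a hypothetical helper, not the package's `apply_partial`):

```julia
using ForwardDiff
using KernelFunctions: SEKernel

# Differentiate k in its first argument (index i, y held fixed), then in its
# second argument (index j): the same nesting as apply_partial above.
function mixed_partial(k, i, j)
    dx(x, y) = ForwardDiff.gradient(t -> k(t, y), x)[i]
    return (x, y) -> ForwardDiff.gradient(t -> dx(x, t), y)[j]
end

mixed_partial(SEKernel(), 1, 2)([0.0, 0.0], [1.0, 0.5])  # ≈ Cov(∂₁Z(x), ∂₂Z(y))
```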
32 changes: 23 additions & 9 deletions test/diffKernel.jl
@@ -1,25 +1,39 @@
@testset "diffKernel" begin
@testset "smoke test" begin
-k = MaternKernel()
-k(1, 1)
+k = EnableDiff(MaternKernel())
+k2 = MaternKernel()
+@test k(1, 1) == k2(1, 1)
k(1, (1, partial(1, 1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1
k(([1], partial(1)), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2]
k(([1, 2], partial(1)), ([1, 2], partial(2)))# Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2]
end

@testset "Sanity Checks with $k" for k in [SEKernel()]
@testset "Sanity Checks with $k1" for k1 in [
SEKernel(),
MaternKernel=5),
RationalQuadraticKernel(),
SEKernel() + RationalQuadraticKernel()
]
k = EnableDiff(k1)
for x in [0, 1, -1, 42]
-# for stationary kernels Cov(∂Z(x), Z(x)) = 0
-@test k((x, partial(1)), x) ≈ 0
+# correlation with self should be positive
+## This fails for Matern and RationalQuadraticKernel
+# because their implementations branch on x == y, resulting in a zero derivative
+# (cf. https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/517)
+@test k((x, partial(1)), (x, partial(1))) > 0

# the slope should be positively correlated with a point further down
@test k(
(x, partial(1)), # slope
-x + 1e-1, # point further down
+x + 1e-2, # point further down
) > 0

-# correlation with self should be positive
-@test k((x, partial(1)), (x, partial(1))) > 0
+@testset "Stationary Tests" begin
+    @test k((x, partial(1)), x) == 0 # expect Cov(∂Z(x), Z(x)) == 0
+
+    @testset "Isotropic Tests" begin
+        @test k(([1, 2], partial(1)), ([1, 2], partial(2))) == 0 # cross covariance should be zero
+    end
+end
end
end
end
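As a quick plausibility check of the stationarity assertion in the tests above: Cov(∂₁Z(x), Z(x)) = ∂ₓ k(x, y) evaluated at y = x, which vanishes for a stationary kernel because k(x, y) = f(x - y) with f even. A finite-difference sketch (the step size and tolerance are arbitrary choices for illustration):

```julia
using KernelFunctions: SEKernel

k = SEKernel()
x, h = 0.5, 1e-6
fd = (k(x + h, x) - k(x - h, x)) / (2h)  # central difference for ∂ₓ k(x, y) at y = x
@assert abs(fd) < 1e-8                   # stationarity: the derivative at x == y vanishes
```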
4 changes: 2 additions & 2 deletions test/runtests.jl
@@ -1,5 +1,5 @@
-using KernelFunctions: KernelFunctions as KF, MaternKernel, SEKernel
-using DifferentiableKernelFunctions: DifferentiableKernelFunctions as DKF, DiffPt, partial
+using KernelFunctions: KernelFunctions as KF, MaternKernel, SEKernel, RationalQuadraticKernel
+using DifferentiableKernelFunctions: DifferentiableKernelFunctions as DKF, EnableDiff, partial
using ProductArrays: productArray
using Test
