Skip to content

Commit

Permalink
Use new interpolation routine in Var
Browse files Browse the repository at this point in the history
This commit removes Interpolations.jl from Var.jl. To do this, the
function `_make_interpolant` was removed. Three new functions are added
which are `_check_interpolant`, `interpolate_point`, and
`interpolate_points`, where the latter two functions replace the
functionality of `_make_interpolant`. Furthermore, the function
`_find_extp_bound_cond` was refactored to `_find_extp_bound_conds` which
find multiple extrapolation condtions using `_find_extp_bound_cond`
which is refactored to find the extrapolation condition for a single
point.

All functions that use an interpolant are updated to use the new
interpolation routine. The test for computing the bias in Atmos changes
to check approximately close to 0.0, due to floating point errors. The
tests that check for errors when interpolating out of bounds now check
for ErrorException instead of BoundsError.
  • Loading branch information
ph-kev committed Dec 13, 2024
1 parent c55b718 commit 25c7808
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 72 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ However, functions like `resampled_as` and interpolating using a `OutputVar` wil
as an interpolant must be generated. This means repeated calls to these functions will be
slower compared to the previous versions of ClimaAnalysis.

## Add interpolation routine
With this release, any functions that rely on interpolation now uses the interpolation
routine written for ClimaAnalysis instead of Interpolations.jl. This substantially reduce
the number and size of allocations when using these functions.

v0.5.12
-------

Expand Down
140 changes: 91 additions & 49 deletions src/Var.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Statistics: mean
import NaNStatistics: nanmean

import ..Numerics
import ..Interpolations
import ..Utils:
nearest_index,
seconds_to_prettystr,
Expand Down Expand Up @@ -87,53 +88,96 @@ struct OutputVar{T <: AbstractArray, A <: AbstractArray, B, C}
end

"""
_make_interpolant(dims, data)
_check_interpolant(dims, data)
Make a linear interpolant from `dims`, a dictionary mapping dimension name to array and
`data`, an array containing data. Used in constructing a `OutputVar`.
Check if it is possible to create an interpolant.
If any element of the arrays in `dims` is a Dates.DateTime, then no interpolant is returned.
Interpolations.jl does not support interpolating on dates. If the longitudes span the entire
range and are equispaced, then a periodic boundary condition is added for the longitude
dimension. If the latitudes span the entire range and are equispaced, then a flat boundary
condition is added for the latitude dimension. In all other cases, an error is thrown when
extrapolating outside of `dim_array`.
If any element of the arrays in `dims` is a Dates.DateTime, then an error is returned. If
the longitudes span the entire range and are equispaced, then a periodic boundary condition
is added for the longitude dimension. If the latitudes span the entire range and are
equispaced, then a flat boundary condition is added for the latitude dimension. In all other
cases, an error is thrown when extrapolating outside of `dim_array`.
"""
function _make_interpolant(dims, data)
# If any element is DateTime, then return nothing for the interpolant because
# Interpolations.jl do not support DateTimes
function _check_interpolant(dims)
# If any element is DateTime, then return an error
# ClimaAnalysis does not support interpolating on dates
for dim_array in values(dims)
eltype(dim_array) <: Dates.DateTime && return nothing
eltype(dim_array) <: Dates.DateTime && return error(
"An interpolant cannot be made because interpolating on dates is not possible",
)
end

# We can only create interpolants when we have 1D dimensions
if isempty(dims) || any(d -> ndims(d) != 1 || length(d) == 1, values(dims))
return nothing
return error(
"An interpolant cannot be made because the dimensions are not 1D",
)
end

# Dimensions are all 1D, check that the knots are in increasing order (as required by
# Interpolations.jl)
# our interpolation routine)
for (dim_name, dim_array) in dims
if !issorted(dim_array)
@warn "Dimension $dim_name is not in increasing order. An interpolant will not be created. See Var.reverse_dim if the dimension is in decreasing order"
return nothing
return error(
"Dimension $dim_name is not in increasing order. An interpolant will not be created. See Var.reverse_dim if the dimension is in decreasing order",
)
end
end
return nothing
end

# Find boundary conditions for extrapolation
extp_bound_conds = (
_find_extp_bound_cond(dim_name, dim_array) for
(dim_name, dim_array) in dims
)
"""
interpolate_point(point, dims, data)
Linearly interpolate the point using `dims` and `data`.
dims_tuple = tuple(values(dims)...)
extp_bound_conds_tuple = tuple(extp_bound_conds...)
return Intp.extrapolate(
Intp.interpolate(dims_tuple, data, Intp.Gridded(Intp.Linear())),
extp_bound_conds_tuple,
Extrapolation conditions are determined by `_find_extp_bound_conds`.
"""
function interpolate_point(point, dims, data)
_check_interpolant(dims)
extp_bound_conds = _find_extp_bound_conds(dims)
return Interpolations.linear_interpolate(
point,
Tuple(values(dims)),
data,
extp_bound_conds,
)
end

"""
interpolate_points(points, dims, data)
Linearly interpolate the points using `dims` and `data`.
Extrapolation conditions are determined by `_find_extp_bound_conds`.
"""
function interpolate_points(points, dims, data)
_check_interpolant(dims)
extp_bound_conds = _find_extp_bound_conds(dims)
dim_arrays_tuple = Tuple(values(dims))
interpolated_arr = [
Interpolations.linear_interpolate(
point,
dim_arrays_tuple,
data,
extp_bound_conds,
) for point in points
]
return interpolated_arr
end

"""
_find_extp_bound_conds(dims)
Find the appropriate boundary conditions given the `dims` of an `OutputVar`.
"""
function _find_extp_bound_conds(dims)
return (
_find_extp_bound_cond(dim_name, dim_array) for
(dim_name, dim_array) in dims
) |> Tuple
end

"""
_find_extp_bound_cond(dim_name, dim_array)
Expand All @@ -151,17 +195,17 @@ function _find_extp_bound_cond(dim_name, dim_array)
conventional_dim_name(dim_name) == "longitude" &&
_isequispaced(dim_array) &&
isapprox(dim_size + dsize, 360.0)
) && return Intp.Periodic()
) && return Interpolations.extp_cond_periodic()
(
conventional_dim_name(dim_name) == "longitude" &&
(dim_array[end] - dim_array[begin]) 360.0
) && return Intp.Periodic()
) && return Interpolations.extp_cond_periodic()
(
conventional_dim_name(dim_name) == "latitude" &&
_isequispaced(dim_array) &&
isapprox(dim_size + dsize, 180.0)
) && return Intp.Flat()
return Intp.Throw()
) && return Interpolations.extp_cond_flat()
return Interpolations.extp_cond_throw()
end

function OutputVar(attribs, dims, dim_attribs, data)
Expand Down Expand Up @@ -981,11 +1025,11 @@ multilinear interpolation.
Extrapolation is now allowed and will throw a `BoundsError` in most cases.
If any element of the arrays of the dimensions is a Dates.DateTime, then interpolation is
not possible. Interpolations.jl do not support making interpolations for dates. If the
longitudes span the entire range and are equispaced, then a periodic boundary condition is
added for the longitude dimension. If the latitudes span the entire range and are
equispaced, then a flat boundary condition is added for the latitude dimension. In all other
cases, an error is thrown when extrapolating outside of the array of the dimension.
not possible. If the longitudes span the entire range and are equispaced, then a periodic
boundary condition is added for the longitude dimension. If the latitudes span the entire
range and are equispaced, then a flat boundary condition is added for the latitude
dimension. In all other cases, an error is thrown when extrapolating outside of the array of
the dimension.
Example
=======
Expand All @@ -1005,8 +1049,7 @@ julia> var2d = ClimaAnalysis.OutputVar(Dict("time" => time, "z" => z), data); va
```
"""
function (x::OutputVar)(target_coord)
itp = _make_interpolant(x.dims, x.data)
return itp(target_coord...)
return interpolate_point(target_coord, x.dims, x.data)
end

"""
Expand Down Expand Up @@ -1143,9 +1186,8 @@ function resampled_as(src_var::OutputVar, dest_var::OutputVar)
src_var = reordered_as(src_var, dest_var)
_check_dims_consistent(src_var, dest_var)

itp = _make_interpolant(src_var.dims, src_var.data)
src_resampled_data =
[itp(pt...) for pt in Base.product(values(dest_var.dims)...)]
coords = Base.product(values(dest_var.dims)...)
src_resampled_data = interpolate_points(coords, src_var.dims, src_var.data)

# Construct new OutputVar to return
src_var_ret_dims = empty(src_var.dims)
Expand Down Expand Up @@ -1756,14 +1798,14 @@ function make_lonlat_mask(

# Resample so that the mask match up with the grid of var
# Round because linear resampling is done and we want the mask to be only ones and zeros
intp = _make_interpolant(mask_var.dims, mask_var.data)
mask_arr =
[
intp(pt...) for pt in Base.product(
input_var.dims[longitude_name(input_var)],
input_var.dims[latitude_name(input_var)],
)
] .|> round
coords = [
pt for pt in Base.product(
input_var.dims[longitude_name(input_var)],
input_var.dims[latitude_name(input_var)],
)
]
mask_arr = interpolate_points(coords, mask_var.dims, mask_var.data)
mask_arr .= mask_arr .|> round

# Reshape data for broadcasting
lon_idx = input_var.dim2index[longitude_name(input_var)]
Expand Down
2 changes: 1 addition & 1 deletion test/test_Atmos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ end
sim_pressure = pressure3D,
obs_pressure = pressure3D,
)
@test global_rmse_pfull == 0.0
@test isapprox(global_rmse_pfull, 0.0, atol = 1e-11)

# Test if the computation is the same as a manual computation
zero_data = zeros(size(data))
Expand Down
53 changes: 31 additions & 22 deletions test/test_Var.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,35 +92,43 @@ end
lon = 0.5:1.0:359.5 |> collect
lat = -89.5:1.0:89.5 |> collect
time = 1.0:100 |> collect
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Periodic(), Intp.Flat(), Intp.Throw())
extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims)
@test extp_conds == (
ClimaAnalysis.Interpolations.extp_cond_periodic(),
ClimaAnalysis.Interpolations.extp_cond_flat(),
ClimaAnalysis.Interpolations.extp_cond_throw(),
)

# Not equispaced for lon and lat
lon = 0.5:1.0:359.5 |> collect |> x -> push!(x, 42.0) |> sort
lat = -89.5:1.0:89.5 |> collect |> x -> push!(x, 42.0) |> sort
time = 1.0:100 |> collect
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Throw(), Intp.Throw(), Intp.Throw())
extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims)
@test extp_conds == (
ClimaAnalysis.Interpolations.extp_cond_throw(),
ClimaAnalysis.Interpolations.extp_cond_throw(),
ClimaAnalysis.Interpolations.extp_cond_throw(),
)

# Does not span entire range for and lat
lon = 0.5:1.0:350.5 |> collect
lat = -89.5:1.0:80.5 |> collect
time = 1.0:100 |> collect
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Throw(), Intp.Throw(), Intp.Throw())
extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims)
@test extp_conds == (
ClimaAnalysis.Interpolations.extp_cond_throw(),
ClimaAnalysis.Interpolations.extp_cond_throw(),
ClimaAnalysis.Interpolations.extp_cond_throw(),
)

# Lon is exactly 360 degrees
lon = 0.0:1.0:360.0 |> collect
data = ones(length(lon))
dims = OrderedDict(["lon" => lon])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Periodic(),)
extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims)
@test extp_conds == (ClimaAnalysis.Interpolations.extp_cond_periodic(),)

# Dates for the time dimension
lon = 0.5:1.0:359.5 |> collect
Expand All @@ -130,17 +138,18 @@ end
Dates.DateTime(2020, 3, 1, 1, 2),
Dates.DateTime(2020, 3, 1, 1, 3),
]
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test isnothing(intp)
@test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims)

# 2D dimensions
arb_dim = reshape(collect(range(-89.5, 89.5, 16)), (4, 4))
data = collect(1:16)
dims = OrderedDict(["arb_dim" => arb_dim])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test isnothing(intp)
@test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims)

# Dimensions are not in increasing order
lon = [0.5, 42.0, 1.5, 110.0]
dims = OrderedDict(["lon" => lon])
@test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims)
end

@testset "empty" begin
Expand Down Expand Up @@ -497,6 +506,7 @@ end
@test ClimaAnalysis.pressure_name(pressure_var) == "pfull"
end

# FIX THIS
@testset "Interpolation" begin
# 1D interpolation with linear data, should yield correct results
long = -175.0:175.0 |> collect
Expand All @@ -507,7 +517,7 @@ end
@test longvar.([10.5, 20.5]) == [10.5, 20.5]

# Test error for data outside of range
@test_throws BoundsError longvar(200.0)
@test_throws ErrorException longvar(200.0)

# 2D interpolation with linear data, should yield correct results
time = 100.0:110.0 |> collect
Expand Down Expand Up @@ -812,7 +822,7 @@ end
@test src_var.data == ClimaAnalysis.resampled_as(src_var, src_var).data
resampled_var = ClimaAnalysis.resampled_as(src_var, dest_var)
@test resampled_var.data == reshape(1.0:(181 * 91), (181, 91))[1:91, 1:46]
@test_throws BoundsError ClimaAnalysis.resampled_as(dest_var, src_var)
@test_throws ErrorException ClimaAnalysis.resampled_as(dest_var, src_var)

# BoundsError check
src_long = 90.0:120.0 |> collect
Expand All @@ -837,7 +847,7 @@ end
dest_data,
)

@test_throws BoundsError ClimaAnalysis.resampled_as(src_var, dest_var)
@test_throws ErrorException ClimaAnalysis.resampled_as(src_var, dest_var)
end

@testset "Units" begin
Expand Down Expand Up @@ -1889,7 +1899,6 @@ end
attribs = Dict("long_name" => "hi")
dim_attribs = OrderedDict(["lon" => Dict("units" => "deg")])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
@test isnothing(ClimaAnalysis.Var._make_interpolant(dims, data))

reverse_var = ClimaAnalysis.reverse_dim(var, "lat")
@test reverse(lat) == reverse_var.dims["lat"]
Expand Down

0 comments on commit 25c7808

Please sign in to comment.