Skip to content

Commit

Permalink
Remove interpolant from OutputVar
Browse files Browse the repository at this point in the history
  • Loading branch information
ph-kev committed Nov 27, 2024
1 parent fedad4e commit 9071cda
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 39 deletions.
13 changes: 13 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
ClimaAnalysis.jl Release Notes
===============================
v0.5.13
-------

## Remove interpolant from OutputVar
With this release, the interpolant (`var.interpolant`) is removed as a field from
`OutputVar`. This was done to prevent the largely unnecessary construction of an interpolant
every time a `OutputVar` is constructed. With this change, a `OutputVar` should take up about
50% less memory.

However, functions like `resampled_as` and interpolating using a `OutputVar` will be slower
as an interpolant must be generated. This means repeated calls to these functions will be
slow than previous versions of ClimaAnalysis.

v0.5.12
-------

Expand Down
37 changes: 23 additions & 14 deletions src/Var.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export OutputVar,
"""
Representing an output variable
"""
struct OutputVar{T <: AbstractArray, A <: AbstractArray, B, C, ITP}
struct OutputVar{T <: AbstractArray, A <: AbstractArray, B, C}

"Attributes associated to this variable, such as short/long name"
attributes::Dict{String, B}
Expand All @@ -84,9 +84,6 @@ struct OutputVar{T <: AbstractArray, A <: AbstractArray, B, C, ITP}

"Array that maps name array index to the dimension name"
index2dim::Vector{String}

"Interpolant from Interpolations.jl, used to evaluate the OutputVar onto any given point."
interpolant::ITP
end

"""
Expand Down Expand Up @@ -184,9 +181,22 @@ function OutputVar(attribs, dims, dim_attribs, data)
dim2index =
Dict([dim_name => index for (index, dim_name) in enumerate(keys(dims))])

# TODO: Make this lazy: we should compute the spline the first time we use
# it, not when we create the object
itp = _make_interpolant(dims, data)
# Check if the size of data matches with the size of dims
if !(
isempty(dims) ||
any(d -> ndims(d) != 1 || length(d) == 1, values(dims))
)
size_data = size(data)
for (dim_index, (dim_name, dim_array)) in enumerate(dims)
dim_length = length(dim_array)
data_length = size_data[dim_index]
if dim_length != data_length
error(
"Dimension $dim_name has inconsistent size with provided data ($dim_length != $data_length)",
)
end
end
end

function _maybe_process_key_value(k, v)
k != "units" && return k => v
Expand All @@ -207,7 +217,6 @@ function OutputVar(attribs, dims, dim_attribs, data)
data,
dim2index,
index2dim,
itp,
)
end

Expand Down Expand Up @@ -1020,10 +1029,8 @@ julia> var2d = ClimaAnalysis.OutputVar(Dict("time" => time, "z" => z), data); va
```
"""
function (x::OutputVar)(target_coord)
isnothing(x.interpolant) && error(
"Splines cannot be constructed because one (or more) of the dimensions of variable is not 1D",
)
return x.interpolant(target_coord...)
itp = _make_interpolant(x.dims, x.data)
return itp(target_coord...)
end

"""
Expand Down Expand Up @@ -1160,8 +1167,9 @@ function resampled_as(src_var::OutputVar, dest_var::OutputVar)
src_var = reordered_as(src_var, dest_var)
_check_dims_consistent(src_var, dest_var)

itp = _make_interpolant(src_var.dims, src_var.data)
src_resampled_data =
[src_var(pt) for pt in Base.product(values(dest_var.dims)...)]
[itp(pt...) for pt in Base.product(values(dest_var.dims)...)]

# Construct new OutputVar to return
src_var_ret_dims = empty(src_var.dims)
Expand Down Expand Up @@ -1772,9 +1780,10 @@ function make_lonlat_mask(

# Resample so that the mask match up with the grid of var
# Round because linear resampling is done and we want the mask to be only ones and zeros
intp = _make_interpolant(mask_var.dims, mask_var.data)
mask_arr =
[
mask_var(pt) for pt in Base.product(
intp(pt...) for pt in Base.product(
input_var.dims[longitude_name(input_var)],
input_var.dims[latitude_name(input_var)],
)
Expand Down
37 changes: 12 additions & 25 deletions test/test_Var.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,39 +94,33 @@ end
time = 1.0:100 |> collect
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
attribs = Dict("long_name" => "hi")
dim_attribs = OrderedDict([
"long" => Dict("units" => "test_units1"),
"lat" => Dict("units" => "test_units2"),
"time" => Dict("units" => "test_units3"),
])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
@test var.interpolant.et == (Intp.Periodic(), Intp.Flat(), Intp.Throw())
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Periodic(), Intp.Flat(), Intp.Throw())

# Not equispaced for lon and lat
lon = 0.5:1.0:359.5 |> collect |> x -> push!(x, 42.0) |> sort
lat = -89.5:1.0:89.5 |> collect |> x -> push!(x, 42.0) |> sort
time = 1.0:100 |> collect
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
@test var.interpolant.et == (Intp.Throw(), Intp.Throw(), Intp.Throw())
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Throw(), Intp.Throw(), Intp.Throw())

# Does not span entire range for and lat
lon = 0.5:1.0:350.5 |> collect
lat = -89.5:1.0:80.5 |> collect
time = 1.0:100 |> collect
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
@test var.interpolant.et == (Intp.Throw(), Intp.Throw(), Intp.Throw())
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Throw(), Intp.Throw(), Intp.Throw())

# Lon is exactly 360 degrees
lon = 0.0:1.0:360.0 |> collect
data = ones(length(lon))
dims = OrderedDict(["lon" => lon])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
@test var.interpolant.et == (Intp.Periodic(),)
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Periodic(),)

# Dates for the time dimension
lon = 0.5:1.0:359.5 |> collect
Expand All @@ -138,14 +132,8 @@ end
]
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
attribs = Dict("long_name" => "hi")
dim_attribs = OrderedDict([
"long" => Dict("units" => "test_units1"),
"lat" => Dict("units" => "test_units2"),
"time" => Dict("units" => "test_units3"),
])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
@test isnothing(var.interpolant)
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test isnothing(intp)
end

@testset "empty" begin
Expand Down Expand Up @@ -550,7 +538,7 @@ end
end

@testset "Long name updates" begin
# Setup to test x_avg, y_avg, xy_avg
# Setup to test x_avg, y_avg, xy_avg
x = 0.0:180.0 |> collect
y = 0.0:90.0 |> collect
time = 0.0:10.0 |> collect
Expand Down Expand Up @@ -1894,8 +1882,7 @@ end
attribs = Dict("long_name" => "hi")
dim_attribs = OrderedDict(["lon" => Dict("units" => "deg")])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)

@test isnothing(var.interpolant)
@test isnothing(ClimaAnalysis.Var._make_interpolant(dims, data))

reverse_var = ClimaAnalysis.reverse_dim(var, "lat")
@test reverse(lat) == reverse_var.dims["lat"]
Expand Down

0 comments on commit 9071cda

Please sign in to comment.