From dd3c8f753c647c149f96755e861841e9ca9bf7cd Mon Sep 17 00:00:00 2001
From: Kevin Phan <98072684+ph-kev@users.noreply.github.com>
Date: Wed, 4 Dec 2024 13:49:17 -0800
Subject: [PATCH] Add interpolation routine

This commit adds an interpolation routine for use in ClimaAnalysis,
which seeks to replace Interpolations.jl in Var.jl. The interpolation
routine supports N-dimensional linear interpolation on a grid with
throw, flat, or periodic boundary conditions. Compared against
Interpolations.jl, the interpolation routine does not make a struct and
does not allocate anything on the heap when interpolating a point.
---
 src/Numerics.jl       | 262 ++++++++++++++++++++++++++++++++++++++
 test/test_Numerics.jl | 285 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 547 insertions(+)

diff --git a/src/Numerics.jl b/src/Numerics.jl
index a04e3c76..4febaf13 100644
--- a/src/Numerics.jl
+++ b/src/Numerics.jl
@@ -166,4 +166,266 @@ function _integration_weights_generic_equispaced(arr)
     return fill(arr[begin + 1] - arr[begin])
 end
 
+#=
+Why not use Interpolations.jl?
+
+One of the most expensive operations in terms of time and memory is
+`Var.resampled_as(src_var, dest_var)` which uses an interpolation to resample `data` in
+`src_var` to match the `dims` in `dest_var`. However, interpolating using Interpolations.jl
+is expensive. One must first initialize an interpolant, whose size in memory is at least as
+big as the memory of `src_var.data`. Then, each evaluation using the interpolant
+costs some amount of allocation on the heap. As a result, interpolating is extremely slow
+for a large number of points.
+
+To solve this, we write our interpolation routine in the Numerics module that supports the following:
+1. No extra dependencies
+2. No allocation when evaluating a point
+3. Support boundary conditions for periodic, flat, and throw on an irregular grid
+4. Comparable or better performance to Interpolations.jl
+=#
+
+"""
+    linear_interpolate(point, axes, data::AbstractArray{FT2, N}, extp_conds) where {N, FT2}
+
+Linearly interpolate `data` on `axes` and return the value at `point`. Extrapolation is
+handled by `extp_conds`.
+
+# Arguments
+- `point`: coordinates to evaluate at, given as a `Number` (1D data only), an
+  `AbstractVector`, or a `Tuple`. Coordinates of mixed types are promoted to a common type.
+- `axes`: one sorted coordinate vector per dimension of `data`.
+- `data`: values on the grid spanned by `axes`.
+- `extp_conds`: one extrapolation condition per dimension of `data`, as returned by
+  `extp_cond_throw`, `extp_cond_flat`, or `extp_cond_periodic`.
+
+Throw a `DimensionMismatch` if the number of coordinates in `point` is not `N`.
+"""
+function linear_interpolate(
+    point::NTuple{N},
+    axes,
+    data::AbstractArray{FT2, N},
+    extp_conds,
+) where {N, FT2}
+    # Get a new point as determined by the extrapolation condition
+    point = extp_to_point(point, axes, extp_conds)
+
+    # Find which cell contains the point
+    cell_indices_for_axes = find_cell_indices_for_axes(point, axes)
+    val = zero(FT2)
+
+    # Compute the denominator of the formula for linear interpolation
+    # (in 1D, this is x_1 - x_0)
+    bottom_term = compute_bottom_term(axes, cell_indices_for_axes)
+
+    # Iterate through all 2^N corners of the cell
+    @inbounds for bits in 0:(2^N - 1)
+        term = one(FT2)
+        bound_indices = get_indices(cell_indices_for_axes, bits)
+        sign = get_sign(cell_indices_for_axes, bits)
+        # Weight is the value at each of the points of the cell
+        weight = data[get_complement_indices(cell_indices_for_axes, bits)...]
+        @inbounds for (dim_idx, bound_idx) in pairs(bound_indices)
+            val_minus_x2_or_x1 = point[dim_idx] - axes[dim_idx][bound_idx]
+            term *= val_minus_x2_or_x1
+        end
+        term *= sign * weight
+        val += term
+    end
+    return val / bottom_term
+end
+
+# A single number is treated as a one-coordinate point.
+function linear_interpolate(
+    point::Number,
+    axes,
+    data::AbstractArray{FT2, N},
+    extp_conds,
+) where {N, FT2}
+    return linear_interpolate((point,), axes, data, extp_conds)
+end
+
+# A vector of coordinates is converted to a tuple.
+function linear_interpolate(
+    point::AbstractVector,
+    axes,
+    data::AbstractArray{FT2, N},
+    extp_conds,
+) where {N, FT2}
+    return linear_interpolate(Tuple(point), axes, data, extp_conds)
+end
+
+# Fallback for tuples that do not match `NTuple{N}`. Coordinates of mixed types are
+# promoted to a common type. A tuple with the wrong number of coordinates is rejected
+# here because promoting it would dispatch back to this method and recurse until the
+# stack overflows.
+function linear_interpolate(
+    point::Tuple,
+    axes,
+    data::AbstractArray{FT2, N},
+    extp_conds,
+) where {N, FT2}
+    length(point) == N || throw(
+        DimensionMismatch(
+            "point has $(length(point)) coordinates, but data has $N dimensions",
+        ),
+    )
+    point = promote(point...)
+    return linear_interpolate(point, axes, data, extp_conds)
+end
+
+
+"""
+    compute_bottom_term(axes, cell_indices_for_axes::NTuple{N})
+
+Compute the denominator used in linear interpolation.
+
+Consider the formula for 1D linear interpolation which is
+    y_0 * ((x_1 - x) / (x_1 - x_0)) + y_1 * ((x - x_0) / (x_1 - x_0))
+for interpolating the point (x, y) on the line between (x_0, y_0) and (x_1, y_1). This
+function computes x_1 - x_0. In N dimensions, it is the product of the cell widths.
+"""
+function compute_bottom_term(axes, cell_indices_for_axes::NTuple{N}) where {N}
+    return reduce(
+        *,
+        ntuple(
+            dim_idx ->
+                axes[dim_idx][cell_indices_for_axes[dim_idx][end]] -
+                axes[dim_idx][cell_indices_for_axes[dim_idx][begin]],
+            N,
+        ),
+    )
+end
+
+"""
+    get_complement_indices(indices::NTuple{N, Tuple{I, I}}, bits) where {N, I}
+
+Given a tuple of 2-tuples, return a tuple with one element from each 2-tuple according to
+`bits`. The elements in the tuple are the complement of those found by `get_indices`.
+"""
+function get_complement_indices(indices::NTuple{N}, bits) where {N}
+    return ntuple(dim -> if (bits & (1 << (dim - 1)) != 0)
+        indices[dim][2]
+    else
+        indices[dim][1]
+    end, N)
+end
+
+"""
+    get_indices(indices::NTuple{N, Tuple{I, I}}, bits) where {N, I}
+
+Given a tuple of 2-tuples, return a tuple with one element from each 2-tuple according to
+`bits`: the first element if bit `dim - 1` of `bits` is set and the second otherwise.
+"""
+function get_indices(indices::NTuple{N}, bits) where {N}
+    return ntuple(dim -> if (bits & (1 << (dim - 1)) != 0)
+        indices[dim][1]
+    else
+        indices[dim][2]
+    end, N)
+end
+
+"""
+    get_sign(indices::NTuple{N, Tuple{I, I}}, bits) where {N, I}
+
+Given a tuple of 2-tuples, compute the appropriate sign when interpolating: the product
+of `1` for every set bit of `bits` and `-1` for every unset bit among the first `N` bits.
+"""
+function get_sign(_indices::NTuple{N}, bits) where {N}
+    return reduce(*, ntuple(dim -> if (bits & (1 << (dim - 1)) != 0)
+        1
+    else
+        -1
+    end, N))
+end
+
+"""
+    find_cell_indices_for_axes(point::NTuple{N, FT},
+                               axes::NTuple{N, A}) where {N, FT, A <:AbstractVector}
+
+Given a point and axes, find the indices of the N-dimensional hyperrectangle that contains
+the point.
+"""
+function find_cell_indices_for_axes(point::NTuple{N}, axes) where {N}
+    return ntuple(
+        dim_idx -> find_cell_indices_for_ax(point[dim_idx], axes[dim_idx]),
+        N,
+    )
+end
+
+"""
+    find_cell_indices_for_ax(val::FT1, ax::AbstractVector{FT2}) where {FT1, FT2}
+
+Given `val` and an `ax`, find the indices of the cell that contains `val`.
+"""
+function find_cell_indices_for_ax(
+    val::FT1,
+    ax::AbstractVector{FT2},
+) where {FT1, FT2}
+    len_of_ax = length(ax)
+    (val == ax[begin]) && return (1, 2)
+    (val == ax[end]) && return (len_of_ax - 1, len_of_ax)
+    idx = searchsortedfirst(ax, val)
+    return (idx - 1, idx)
+end
+
+"""
+    extp_to_point(point::NTuple{N, FT1}, axes::NTuple{N, Vector}, extp_conds) where {N, FT1}
+
+Return a new point to evaluate at according to the extrapolation conditions.
+"""
+function extp_to_point(point::NTuple{N}, axes, extp_conds) where {N}
+    return ntuple(
+        idx -> extp_conds[idx].get_val_for_point(point[idx], axes[idx]),
+        N,
+    )
+end
+
+"""
+    extp_cond_throw()
+
+Return extrapolation condition for throwing an error when the point is outside of bounds.
+"""
+function extp_cond_throw()
+    get_val_for_point(val, ax) = begin
+        val < ax[begin] && return error("Out of bounds error with $val in $ax")
+        val > ax[end] && return error("Out of bounds error with $val in $ax")
+        return val
+    end
+    return (; get_val_for_point)
+end
+
+"""
+    extp_cond_flat()
+
+Return flat extrapolation condition.
+"""
+function extp_cond_flat()
+    get_val_for_point(val, ax) = begin
+        val < ax[begin] && return typeof(val)(ax[begin])
+        val > ax[end] && return typeof(val)(ax[end])
+        return val
+    end
+    return (; get_val_for_point)
+end
+
+"""
+    extp_cond_periodic()
+
+Return periodic extrapolation condition.
+
+The first and last nodes are not co-located. For example, if the axis is [1.0, 2.0, 3.0]
+and the data is [4.0, 5.0, 6.0], then the value at 3.0 is 6.0 and not 4.0.
+"""
+function extp_cond_periodic()
+    get_val_for_point(val, ax) = begin
+        if (val < ax[begin]) || (val > ax[end])
+            width = ax[end] - ax[begin]
+            new_val = mod(val - ax[begin], width) + ax[begin]
+            return typeof(val)(new_val)
+        end
+        return val
+    end
+    return (; get_val_for_point)
+end
+
 end
diff --git a/test/test_Numerics.jl b/test/test_Numerics.jl
index 4e38940a..59a1a896 100644
--- a/test/test_Numerics.jl
+++ b/test/test_Numerics.jl
@@ -165,3 +165,288 @@ end
         2.5,
     )
 end
+
+@testset "Get indices and sign" begin
+    indices = ((1, 2),)
+    @test ClimaAnalysis.Numerics.get_indices(indices, 0) == (2,)
+    @test ClimaAnalysis.Numerics.get_indices(indices, 1) == (1,)
+    @test ClimaAnalysis.Numerics.get_complement_indices(indices, 0) == (1,)
+    @test ClimaAnalysis.Numerics.get_complement_indices(indices, 1) == (2,)
+    @test ClimaAnalysis.Numerics.get_sign(indices, 0) == -1
+    @test ClimaAnalysis.Numerics.get_sign(indices, 1) == 1
+
+    indices = ((1, 2), (3, 4))
+    @test ClimaAnalysis.Numerics.get_indices(indices, 0) == (2, 4)
+    @test ClimaAnalysis.Numerics.get_indices(indices, 1) == (1, 4)
+    @test ClimaAnalysis.Numerics.get_indices(indices, 2) == (2, 3)
+    @test ClimaAnalysis.Numerics.get_indices(indices, 3) == (1, 3)
+    @test ClimaAnalysis.Numerics.get_complement_indices(indices, 0) == (1, 3)
+    @test ClimaAnalysis.Numerics.get_complement_indices(indices, 1) == (2, 3)
+    @test ClimaAnalysis.Numerics.get_complement_indices(indices, 2) == (1, 4)
+    @test ClimaAnalysis.Numerics.get_complement_indices(indices, 3) == (2, 4)
+    @test ClimaAnalysis.Numerics.get_sign(indices, 0) == 1
+    @test ClimaAnalysis.Numerics.get_sign(indices, 1) == -1
+    @test ClimaAnalysis.Numerics.get_sign(indices, 2) == -1
+    @test ClimaAnalysis.Numerics.get_sign(indices, 3) == 1
+end
+
+@testset "Find indices for cell" begin
+    val1 = 5
+    ax1 = [0, 10]
+    @test ClimaAnalysis.Numerics.find_cell_indices_for_ax(val1, ax1) == (1, 2)
+
+    val2 = 6
+    ax2 = [0, 4, 10]
+    @test ClimaAnalysis.Numerics.find_cell_indices_for_ax(val2, ax2) == (2, 3)
+
+    @test ClimaAnalysis.Numerics.find_cell_indices_for_axes(
+        (val1, val2),
+        (ax1, ax2),
+    ) == ((1, 2), (2, 3))
+end
+
+@testset "Extrapolation conditions" begin
+    throw = ClimaAnalysis.Numerics.extp_cond_throw()
+    flat = ClimaAnalysis.Numerics.extp_cond_flat()
+    periodic = ClimaAnalysis.Numerics.extp_cond_periodic()
+
+    ax = [0, 1, 2, 3]
+    @test_throws ErrorException throw.get_val_for_point(10, ax)
+    @test_throws ErrorException throw.get_val_for_point(-1, ax)
+    @test throw.get_val_for_point(1.5, ax) == 1.5
+
+    @test flat.get_val_for_point(10, ax) == 3
+    @test flat.get_val_for_point(-1, ax) == 0
+    @test flat.get_val_for_point(1.5, ax) == 1.5
+
+    @test periodic.get_val_for_point(10, ax) == 1
+    @test periodic.get_val_for_point(-1, ax) == 2
+    @test periodic.get_val_for_point(1.5, ax) == 1.5
+    @test periodic.get_val_for_point(3, ax) == 3
+end
+
+@testset "Extrapolate to new point" begin
+    throw = ClimaAnalysis.Numerics.extp_cond_throw()
+    flat = ClimaAnalysis.Numerics.extp_cond_flat()
+    periodic = ClimaAnalysis.Numerics.extp_cond_periodic()
+
+    ax1 = [0, 1, 2, 3]
+    @test ClimaAnalysis.Numerics.extp_to_point((1,), (ax1,), (throw,)) == (1,)
+
+    ax2 = [4, 5, 6, 7]
+    @test ClimaAnalysis.Numerics.extp_to_point(
+        (-1, 8),
+        (ax1, ax2),
+        (flat, periodic),
+    ) == (0, 5)
+end
+
+@testset "Interpolation" begin
+    throw = ClimaAnalysis.Numerics.extp_cond_throw()
+    flat = ClimaAnalysis.Numerics.extp_cond_flat()
+    periodic = ClimaAnalysis.Numerics.extp_cond_periodic()
+
+    # 1D case
+    axes = ([1.0, 2.0, 3.0],)
+    data = [3.0, 1.0, 0.0]
+
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (1.0,),
+        axes,
+        data,
+        (throw,),
+    ) == 3.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (3.0,),
+        axes,
+        data,
+        (throw,),
+    ) == 0.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (1.5,),
+        axes,
+        data,
+        (throw,),
+    ) == 2.0
+
+    # 1D case with extrapolation conditions
+    @test_throws ErrorException ClimaAnalysis.Numerics.linear_interpolate(
+        (0.0,),
+        axes,
+        data,
+        (throw,),
+    )
+    @test_throws ErrorException ClimaAnalysis.Numerics.linear_interpolate(
+        (4.0,),
+        axes,
+        data,
+        (throw,),
+    )
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (0.0,),
+        axes,
+        data,
+        (flat,),
+    ) == 3.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (4.0,),
+        axes,
+        data,
+        (flat,),
+    ) == 0.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (0.0,),
+        axes,
+        data,
+        (periodic,),
+    ) == 1.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (4.0,),
+        axes,
+        data,
+        (periodic,),
+    ) == 1.0
+
+    # 2D case
+    axes = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
+    data = reshape(1:9, (3, 3))
+
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (1.0, 4.0),
+        axes,
+        data,
+        (throw, throw),
+    ) == 1.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (3.0, 6.0),
+        axes,
+        data,
+        (throw, throw),
+    ) == 9.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (2.0, 5.0),
+        axes,
+        data,
+        (throw, throw),
+    ) == 5.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (1.5, 4.5),
+        axes,
+        data,
+        (throw, throw),
+    ) == 3.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (1.5, 5.5),
+        axes,
+        data,
+        (throw, throw),
+    ) == 6.0
+
+    # 2D cases with extrapolation conditions
+    @test_throws ErrorException ClimaAnalysis.Numerics.linear_interpolate(
+        (4.0, 5.0),
+        axes,
+        data,
+        (throw, flat),
+    )
+    @test_throws ErrorException ClimaAnalysis.Numerics.linear_interpolate(
+        (2.0, 7.0),
+        axes,
+        data,
+        (flat, throw),
+    )
+    @test_throws ErrorException ClimaAnalysis.Numerics.linear_interpolate(
+        (0.0, 8.0),
+        axes,
+        data,
+        (throw, throw),
+    )
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (0.0, 8.0),
+        axes,
+        data,
+        (flat, flat),
+    ) == 7.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (4.0, 7.0),
+        axes,
+        data,
+        (periodic, periodic),
+    ) == 5.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (3.0, 6.0),
+        axes,
+        data,
+        (periodic, periodic),
+    ) == 9.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (4.0, 7.0),
+        axes,
+        data,
+        (flat, periodic),
+    ) == 6.0
+
+    # 3D cases (all points are inside the grid)
+    axes = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0])
+    data = reshape(1:27, (3, 3, 3))
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (1.0, 5.0, 7.0),
+        axes,
+        data,
+        (throw, throw, throw),
+    ) == 4.0
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (1.5, 5.2, 7.5),
+        axes,
+        data,
+        (throw, throw, throw),
+    ) ≈ 9.6
+
+    # Different types
+    # Axes have different types and inputs have different types
+    axes = ([1.0f0, 2.0f0], [Float16(3.0), Float16(4.0)])
+    data = [[1.0, 2.0] [3.0, 4.0]]
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (1.5f0, 3.5),
+        axes,
+        data,
+        (throw, throw),
+    ) == 2.5
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        (1.5, 4.5f0),
+        axes,
+        data,
+        (flat, flat),
+    ) == 3.5
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        [1.5, 4.5f0],
+        axes,
+        data,
+        (flat, flat),
+    ) == 3.5
+
+    # Single number
+    axes = ([1.0, 2.0, 3.0],)
+    data = [3.0, 1.0, 0.0]
+
+    @test ClimaAnalysis.Numerics.linear_interpolate(
+        1.0,
+        axes,
+        data,
+        (throw,),
+    ) == 3.0
+
+    # Number of coordinates does not match the number of dimensions of the data
+    @test_throws DimensionMismatch ClimaAnalysis.Numerics.linear_interpolate(
+        (1.0, 2.0),
+        axes,
+        data,
+        (throw,),
+    )
+    @test_throws DimensionMismatch ClimaAnalysis.Numerics.linear_interpolate(
+        [1.0, 2.0],
+        axes,
+        data,
+        (throw,),
+    )
+end