Skip to content

Commit

Permalink
Update functions and tests to use Var.remake
Browse files Browse the repository at this point in the history
  • Loading branch information
ph-kev committed Dec 13, 2024
1 parent 0fa2f90 commit bdfed7d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 110 deletions.
51 changes: 18 additions & 33 deletions src/Var.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,6 @@ function reordered_as(src_var::OutputVar, dest_var::OutputVar)
# Reorder dims, dim_attribs, and data, but not attribs
ret_dims = deepcopy(src_var.dims)
ret_dims = OrderedDict(collect(ret_dims)[reorder_indices])
ret_attribs = deepcopy(src_var.attributes)

# Cannot assume that every dimension is present in dim_attribs so we loop to reorder the
# best we can and merge with src_var.dim_attributes to add any remaining pairs to
Expand All @@ -1146,7 +1145,12 @@ function reordered_as(src_var::OutputVar, dest_var::OutputVar)

ret_data = copy(src_var.data)
ret_data = permutedims(ret_data, reorder_indices)
return OutputVar(ret_attribs, ret_dims, ret_dim_attribs, ret_data)
return remake(
src_var,
dims = ret_dims,
data = ret_data,
dim_attributes = ret_dim_attribs,
)
end

"""
Expand All @@ -1172,14 +1176,7 @@ function resampled_as(src_var::OutputVar, dest_var::OutputVar)
for (dim_name, dim_data) in zip(keys(src_var.dims), values(dest_var.dims))
src_var_ret_dims[dim_name] = copy(dim_data)
end
scr_var_ret_attribs = deepcopy(src_var.attributes)
scr_var_ret_dim_attribs = deepcopy(src_var.dim_attributes)
return OutputVar(
scr_var_ret_attribs,
src_var_ret_dims,
scr_var_ret_dim_attribs,
src_resampled_data,
)
return remake(src_var, dims = src_var_ret_dims, data = src_resampled_data)
end

"""
Expand Down Expand Up @@ -1334,10 +1331,8 @@ function split_by_season(var::OutputVar)
return OutputVar(dims, data)
end
ret_dims = deepcopy(var.dims)
ret_attribs = deepcopy(var.attributes)
ret_dim_attribs = deepcopy(var.dim_attributes)
ret_dims[time_name(var)] = time
OutputVar(ret_attribs, ret_dims, ret_dim_attribs, data)
remake(var, dims = ret_dims, data = data)
end
end

Expand Down Expand Up @@ -1633,8 +1628,12 @@ function _dates_to_seconds(
dim_name => dim_data for (dim_name, dim_data) in var_dims
)
ret_dims = OrderedDict(ret_dims_generator...)
ret_data = copy(var.data)
return OutputVar(ret_attribs, ret_dims, ret_dim_attribs, ret_data)
return remake(
var,
attributes = ret_attribs,
dims = ret_dims,
dim_attributes = ret_dim_attribs,
)
end

"""
Expand Down Expand Up @@ -1685,9 +1684,7 @@ function shift_to_start_of_previous_month(var::OutputVar)
ret_attribs["start_date"] = string(start_date)
ret_dims = deepcopy(var.dims)
ret_dims["time"] = time_arr
ret_dim_attributes = deepcopy(var.dim_attributes)
ret_data = copy(var.data)
return OutputVar(ret_attribs, ret_dims, ret_dim_attributes, ret_data)
return remake(var, attributes = ret_attribs, dims = ret_dims)
end

"""
Expand Down Expand Up @@ -1804,11 +1801,7 @@ function make_lonlat_mask(
mask_arr = reshape(mask_arr, size_to_reshape...)
masked_data = input_var.data .* mask_arr

# Remake OutputVar with new data
ret_attribs = deepcopy(input_var.attributes)
ret_dims = deepcopy(input_var.dims)
ret_dim_attributes = deepcopy(input_var.dim_attributes)
return OutputVar(ret_attribs, ret_dims, ret_dim_attributes, masked_data)
return remake(input_var, data = masked_data)
end
end

Expand All @@ -1824,12 +1817,7 @@ you want to use the ocean mask, but there are `NaN`s in the ocean. You can repla
"""
function Base.replace(var::OutputVar, old_new::Pair...)
replaced_data = replace(var.data, old_new...)

# Remake OutputVar with replaced_data
ret_attribs = deepcopy(var.attributes)
ret_dims = deepcopy(var.dims)
ret_dim_attributes = deepcopy(var.dim_attributes)
return OutputVar(ret_attribs, ret_dims, ret_dim_attributes, replaced_data)
return remake(var, data = replaced_data)
end

"""
Expand All @@ -1855,10 +1843,7 @@ function reverse_dim(var::OutputVar, dim_name)
dim_idx = get(var.dim2index, dim_name, dim_name)
ret_data = var.data |> copy |> (A -> reverse(A, dims = dim_idx))

# Remake OutputVar
ret_attribs = deepcopy(var.attributes)
ret_dim_attributes = deepcopy(var.dim_attributes)
return OutputVar(ret_attribs, ret_dims, ret_dim_attributes, ret_data)
return remake(var, dims = ret_dims, data = ret_data)
end

"""
Expand Down
102 changes: 25 additions & 77 deletions test/test_Var.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ import Dates
)
end

@testset "remake" begin
@testset "Remake" begin
lat = collect(range(-89.5, 89.5, 180))
lon = collect(range(-179.5, 179.5, 360))
data = ones(length(lat), length(lon))
Expand Down Expand Up @@ -767,24 +767,12 @@ end
"extra_info" => "hi",
"lat" => Dict("units" => "test_units2"),
])
src_var_one = ClimaAnalysis.OutputVar(
src_attribs,
src_dims,
src_dim_attribs_one,
src_data,
)
src_var_empty = ClimaAnalysis.OutputVar(
src_attribs,
src_dims,
src_dim_attribs_empty,
src_data,
)
src_var_extra = ClimaAnalysis.OutputVar(
src_attribs,
src_dims,
src_dim_attribs_extra,
src_data,
)
src_var_one =
ClimaAnalysis.remake(src_var, dim_attributes = src_dim_attribs_one)
src_var_empty =
ClimaAnalysis.remake(src_var, dim_attributes = src_dim_attribs_empty)
src_var_extra =
ClimaAnalysis.remake(src_var, dim_attributes = src_dim_attribs_extra)
reordered_var = ClimaAnalysis.reordered_as(src_var_one, dest_var)
@test reordered_var.dim_attributes == src_dim_attribs_one
reordered_var = ClimaAnalysis.reordered_as(src_var_empty, dest_var)
Expand Down Expand Up @@ -829,17 +817,7 @@ end
dest_lat = 0.0:45.0 |> collect
dest_data = reshape(1.0:(91 * 46), (91, 46))
dest_dims = OrderedDict(["long" => dest_long, "lat" => dest_lat])
dest_attribs = Dict("long_name" => "hi")
dest_dim_attribs = OrderedDict([
"long" => Dict("units" => "test_units1"),
"lat" => Dict("units" => "test_units2"),
])
dest_var = ClimaAnalysis.OutputVar(
dest_attribs,
dest_dims,
dest_dim_attribs,
dest_data,
)
dest_var = ClimaAnalysis.remake(src_var, data = dest_data, dims = dest_dims)

@test src_var.data == ClimaAnalysis.resampled_as(src_var, src_var).data
resampled_var = ClimaAnalysis.resampled_as(src_var, dest_var)
Expand All @@ -851,23 +829,14 @@ end
src_lat = 45.0:90.0 |> collect
src_data = zeros(length(src_long), length(src_lat))
src_dims = OrderedDict(["long" => src_long, "lat" => src_lat])
src_var = ClimaAnalysis.OutputVar(
src_attribs,
src_dims,
src_dim_attribs,
src_data,
)
src_var = ClimaAnalysis.remake(src_var, dims = src_dims, data = src_data)

dest_long = 85.0:115.0 |> collect
dest_lat = 50.0:85.0 |> collect
dest_data = zeros(length(dest_long), length(dest_lat))
dest_dims = OrderedDict(["long" => dest_long, "lat" => dest_lat])
dest_var = ClimaAnalysis.OutputVar(
dest_attribs,
dest_dims,
dest_dim_attribs,
dest_data,
)
dest_var =
ClimaAnalysis.remake(dest_var, data = dest_data, dims = dest_dims)

@test_throws BoundsError ClimaAnalysis.resampled_as(src_var, dest_var)
end
Expand All @@ -886,18 +855,16 @@ end
dim_attributes,
data,
)
var_without_unitful = ClimaAnalysis.OutputVar(
Dict{String, Any}(),
Dict("long" => long),
dim_attributes,
data,
var_without_unitful = ClimaAnalysis.remake(
var_with_unitful,
attributes = Dict{String, Any}(),
dims = Dict("long" => long),
)

var_empty_unit = ClimaAnalysis.OutputVar(
Dict{String, Any}("units" => ""),
Dict("long" => long),
dim_attributes,
data,
var_empty_unit = ClimaAnalysis.remake(
var_with_unitful,
attributes = Dict{String, Any}("units" => ""),
dims = Dict("long" => long),
)

@test ClimaAnalysis.has_units(var_with_unitful)
Expand Down Expand Up @@ -1605,12 +1572,7 @@ end
lat = collect(range(-89.5, 89.5, 180))
data = ones(length(lon), length(lat))
dims = OrderedDict(["lon" => lon, "lat" => lat])
attribs = Dict("long_name" => "hi")
dim_attribs = OrderedDict([
"lon" => Dict("units" => "deg"),
"lat" => Dict("units" => "deg"),
])
var_lonlat = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
var_lonlat = ClimaAnalysis.remake(var_latlon, dims = dims, data = data)

land_var_lonlat = ClimaAnalysis.apply_landmask(var_lonlat)
ocean_var_lonlat = ClimaAnalysis.apply_oceanmask(var_lonlat)
Expand Down Expand Up @@ -1834,25 +1796,15 @@ end
@test ClimaAnalysis.dim_units(var, "lat") == "degrees"

# Units do not exist in dim_attribs as a key
lat = collect(range(-89.5, 89.5, 180))
lon = collect(range(-179.5, 179.5, 360))
data = ones(length(lat), length(lon))
dims = OrderedDict(["lat" => lat, "lon" => lon])
attribs = Dict("long_name" => "hi")
dim_attribs =
OrderedDict(["lat" => Dict(), "lon" => Dict("units" => "deg")])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
var = ClimaAnalysis.remake(var, dim_attributes = dim_attribs)
ClimaAnalysis.set_dim_units!(var, "lat", "degrees")
@test ClimaAnalysis.dim_units(var, "lat") == "degrees"

# Dimension is not present in dim_attribs
lat = collect(range(-89.5, 89.5, 180))
lon = collect(range(-179.5, 179.5, 360))
data = ones(length(lat), length(lon))
dims = OrderedDict(["lat" => lat, "lon" => lon])
attribs = Dict("long_name" => "hi")
dim_attribs = OrderedDict(["lon" => Dict("units" => "deg")])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
var = ClimaAnalysis.remake(var, dim_attributes = dim_attribs)
ClimaAnalysis.set_dim_units!(var, "lat", "degrees")
@test ClimaAnalysis.dim_units(var, "lat") == "degrees"

Expand All @@ -1874,7 +1826,7 @@ end
attribs = Dict("long_name" => "hi")
dim_attribs = OrderedDict(["lon" => Dict("units" => "deg")])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
ones_var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data_ones)
ones_var = ClimaAnalysis.remake(var, data = data_ones)

mask_fn = ClimaAnalysis.make_lonlat_mask(
var,
Expand All @@ -1897,19 +1849,15 @@ end
lon = collect(range(-179.5, 179.5, 360))
data = ones(length(lon))
dims = OrderedDict(["lon" => lon])
attribs = Dict("long_name" => "hi")
dim_attribs = OrderedDict(["lon" => Dict("units" => "deg")])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
var = ClimaAnalysis.remake(var, data = data, dims = dims)
@test_throws ErrorException ClimaAnalysis.make_lonlat_mask(var)

lat = collect(range(-89.5, 89.5, 180))
lon = collect(range(-179.5, 179.5, 360))
t = collect(range(1, 2, 2))
data = ones(length(lat), length(lon), length(t))
dims = OrderedDict(["lat" => lat, "lon" => lon, "time" => t])
attribs = Dict("long_name" => "hi")
dim_attribs = OrderedDict(["lon" => Dict("units" => "deg")])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
var = ClimaAnalysis.remake(var, data = data, dims = dims)
@test_throws ErrorException ClimaAnalysis.make_lonlat_mask(var)
end

Expand Down

0 comments on commit bdfed7d

Please sign in to comment.