Skip to content

Commit

Permalink
Extend tests for restarts
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed Sep 19, 2024
1 parent 448d485 commit 139029b
Showing 1 changed file with 197 additions and 129 deletions.
326 changes: 197 additions & 129 deletions test/restart.jl
Original file line number Diff line number Diff line change
@@ -1,69 +1,108 @@
import ClimaAtmos as CA
import ClimaCore: Fields, Geometry
import ClimaCore
import ClimaCore: DataLayouts, Fields, Geometry
import ClimaComms
pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends
import Logging
using Test

function compare(
one,
two;
name = "",
ignore = [
:scratch,
:ghost_buffer,
:output_dir,
:hyperdiffusion_ghost_buffer,
],
)
# This test checks that:

# 1. A simulation, saved to a checkpoint, is read back identically (up to some
# tolerance and excluding those fields that are computed during the
# calculation of the tendencies)
# 2. A simulation, saved to a previous checkpoint, and read back and evolved to
# the same time is identical (up to some tolerance)
# 3. ClimaAtmos can automatically detect restarts
# 4. It is not possible to change the Atmos model across simulations

"""
compare(one, two; name = "", ignore = Set([]))
Recursively compare `one` and `two` up to some numeric tolerance.
`compare` walks through all the properties in `one` and `two` until it finds
that there are no more properties. At that point, `compare` tries to match the
resulting objects. When such objects are arrays with floating point, `compare`
defines a notion of `error` that is the following: when the absolute value is
less than `100eps(eltype)`, `error = absolute_error`, otherwise it is relative
error. The `error` is then compared against a tolerance.
Keyword arguments
=================
- `name` is used to collect the name of the property while we go recursively
over all the properties. You can pass a base name.
- `ignore` is a collection of `Symbol`s that identify properties that are
ignored when walking through the tree. This is useful for properties that
are known to be different (e.g., `output_dir`).
"""
function compare(one, two; name = "", ignore = Set([]))
propertynames(one) == propertynames(two) ||
error("Cannot compare these objects")

properties = filter(x -> !(x in ignore), propertynames(one))

if isempty(properties)
# Base case

@test typeof(one) == typeof(two)
# Objects that we know how to compare
if one isa Fields.Field || one isa Array || one isa Number
if one isa Fields.Field ||
one isa Array ||
one isa DataLayouts.AbstractData
if eltype(one) <: Geometry.AxisTensor
# We have to handle the AxisTensor case separately because it
# behaves differently
for (f, g) in zip(one, two)
compare(f, g; name = "$(name)")
end
elseif eltype(one) <: CartesianIndex
@test one == two
else
one isa Number && (one = [one])
two isa Number && (two = [two])

arr1 = Array(parent(one))
arr2 = Array(parent(two))

# Calculate element-wise relative difference, avoiding division by zero
diff = abs.(arr1 .- arr2)
denominator = abs.(arr1)
relative_diff =
ifelse.(
denominator .> 0,
diff ./ denominator,
ifelse.(diff .== 0, 0.0, Inf),
)

if !isempty(relative_diff)
# Check if the max relative differences is within tolerance
max_error = maximum(relative_diff)
println("Relative error in $name: ", max_error)
else
println("$name is empty on both sides")
end
if eltype(arr1) <: AbstractFloat
@test max_error < 100eps(eltype(arr1))

# We compute the error in this way:
# - when the absolute value is larger than ABS_TOL, we use the
# absolute error
# - in the other cases, we compare the relative errors

ABS_TOL = 100eps(eltype(arr1))
TOL = 100eps(eltype(arr1))

# Calculate element-wise relative difference, avoiding division by zero
diff = abs.(arr1 .- arr2)
denominator = abs.(arr1)
error =
ifelse.(
denominator .> ABS_TOL,
diff ./ denominator,
diff,
)

if !isempty(error)
max_error = maximum(error)
println("$name $(max_error)")
@test max_error <= TOL
else
println("$name is empty on both sides")
end
else
@test max_error 0
@test arr1 == arr2
end
end
elseif Base.issingletontype(typeof(one))
# We have already tested this
elseif one isa AbstractString
elseif one isa AbstractString || one isa Symbol
println("$name $one == $two")
@test one == two
elseif one isa Number
println("$name $one == $two")
# We check with triple equal so that we also catch NaNs being equal
@test one === two
else
println("Cannot compare $name")
@test false
Expand All @@ -84,127 +123,156 @@ end
# Disable all the @info statements that are produced when creating a simulation
Logging.disable_logging(Logging.Info)

### Test Description
# Generate a simulation with some complexity of
# config arguments. Some config combinations are
# incompatible so we do not sweep over all possible
# iterations.

# Modify the timestep to 1-second increments.
# Save simulation state at each timestep,
# and generate a restart file at 0secs, 2secs simulation time.
# Verify objects read in using ClimaCore.InputOutput functions
# are identical (i.e. restarts result
# in the same simulation states as if one were to advance
# the timestepper uninterrupted.)

# TODO: Restart and diagnostic behaviour needs to be
# clearly defined when config files have different
# settings (or when tendency computations conflict with
# dt or t_end parsed args)

# for configuration in ["sphere", "column"]
# for moisture in ["equil"]
# for turb_conv in ["diagnostic_edmfx", "prognostic_edmfx"]
# for precip in ["0M", "1M"]

configuration = "sphere"
moisture = "equil"
turb_conv = "prognostic_edmfx"
precip = "0M"
for bubble in (true, false)

configuration = "sphere"
moisture = "equil"
turb_conv = "prognostic_edmfx"
precip = "0M"

mktempdir() do output_loc
# The `enable_bubble` case is broken for ClimaCore < 0.14.6, so we
# hardcode this to be always false for those versions
pkgversion(ClimaCore) < v"0.14.6" && (bubble = false)

job_id = "restart"
test_dict = Dict(
"test_dycore_consistency" => true, # We will add NaNs to the cache, just to make sure
"check_nan_every" => 3,
"log_progress" => false,
"moist" => moisture,
"precip_model" => precip,
"config" => configuration,
# "turbconv" => turb_conv,
"perturb_initstate" => false,
"dt" => "1secs",
"bubble" => bubble,
# "insolation" => "timevarying",
# "rad" => "clearsky",
# "dt_rad" => "1secs",
# "surface_setup" => "DefaultMoninObukhov",
# "implicit_diffusion" => true,
"t_end" => "3secs",
"dt_save_state_to_disk" => "1secs",
"enable_diagnostics" => false,
"output_dir" => joinpath(output_loc, job_id),
)

println("output_dir: $(test_dict["output_dir"])")

config = CA.AtmosConfig(test_dict, job_id = job_id)

simulation = CA.get_simulation(config)
CA.solve_atmos!(simulation)

# Check re-importing the same state
restart_dir = simulation.output_dir
@test isfile(joinpath(restart_dir), "day0.3.hdf5")

config_should_be_same = CA.AtmosConfig(
merge(test_dict, Dict("detect_restart_file" => true)),
job_id = job_id,
)

simulation_restarted = CA.get_simulation(config_should_be_same)
println(
"Checking integrator for the case where we just read the data",
)
compare(
simulation.integrator.u,
simulation_restarted.integrator.u;
name = "integrator.u",
)
compare(
axes(simulation.integrator.u.c),
axes(simulation_restarted.integrator.u.c);
name = "center_space",
)
compare(
axes(simulation.integrator.u.f),
axes(simulation_restarted.integrator.u.f);
name = "face_space",
)
compare(
simulation.integrator.p,
simulation_restarted.integrator.p;
name = "integrator.p",
ignore = Set([
:scratch,
:output_dir,
# Computed in tendencies (which are not computed in this case)
:ghost_buffer,
:hyperdiff,
:precipitation,
]),
)

# Check re-importing from previous state and advancing one step
restart_file = joinpath(simulation.output_dir, "day0.2.hdf5")
@test isfile(joinpath(restart_dir), "day0.2.hdf5")
# Restart from specific file
config2 = CA.AtmosConfig(
merge(test_dict, Dict("restart_file" => restart_file)),
job_id = job_id,
)

simulation_restarted2 = CA.get_simulation(config2)
CA.fill_with_nans!(simulation_restarted2.integrator.p)

CA.solve_atmos!(simulation_restarted2)
println("Checking integrator.u for the case where we start from 2s")
compare(
simulation.integrator.u,
simulation_restarted2.integrator.u;
name = "integrator.u",
)
compare(
simulation.integrator.p,
simulation_restarted2.integrator.p;
name = "integrator.p",
ignore = Set([:scratch, :output_dir]),
)

# end
# end
# end
# end
end
end
end

@testset "Test incompatible restart" begin
mktempdir() do output_loc
job_id = "restart_$(configuration)_$(moisture)_$(turb_conv)_$(precip)"
job_id = "my_job"
test_dict = Dict(
"check_nan_every" => 3,
"log_progress" => false,
"moist" => moisture,
"precip_model" => precip,
"config" => configuration,
"turbconv" => turb_conv,
"perturb_initstate" => false,
"dt" => "1secs",
"insolation" => "timevarying",
"rad" => "allskywithclear",
"surface_setup" => "DefaultMoninObukhov",
"implicit_diffusion" => true,
"t_end" => "3secs",
"dt_save_state_to_disk" => "1secs",
"enable_diagnostics" => false,
"output_dir" => joinpath(output_loc, job_id),
"enable_diagnostics" => false,
)

println("output_dir: $(test_dict["output_dir"])")

config = CA.AtmosConfig(test_dict, job_id = job_id)

simulation = CA.get_simulation(config)
CA.solve_atmos!(simulation)

# Check re-importing the same state
restart_dir = simulation.output_dir
@test isfile(joinpath(restart_dir), "day0.3.hdf5")

config_should_be_same = CA.AtmosConfig(
merge(test_dict, Dict("detect_restart_file" => true)),
job_id = job_id,
)

simulation_restarted = CA.get_simulation(config_should_be_same)
println("Check file-read from checkpoint data")
println("Checking integrator for the case where we just read the data")
compare(
simulation.integrator.u,
simulation_restarted.integrator.u;
name = "integrator.u",
)
compare(
simulation.integrator.p,
simulation_restarted.integrator.p;
name = "integrator.p",
)

# Check re-importing from previous state and advancing one step
restart_file = joinpath(simulation.output_dir, "day0.2.hdf5")
@test isfile(joinpath(restart_dir), "day0.2.hdf5")
println("Restart from specific file")
config2 = CA.AtmosConfig(
merge(test_dict, Dict("restart_file" => restart_file)),
job_id = job_id,
)

simulation_restarted2 = CA.get_simulation(config2)
println("Advancing restarted simulation")
CA.solve_atmos!(simulation_restarted2)
println("Restarted simulation complete")
println("Checking integrator.u for the case where we start from 2s")
compare(
simulation.integrator.u,
simulation_restarted2.integrator.u;
name = "integrator.u",
)
compare(
simulation.integrator.p,
simulation_restarted2.integrator.p;
name = "integrator.p",
)

# Test that we can catch an Atmos model changing across restarts
config_different = CA.AtmosConfig(
merge(
test_dict,
Dict(
"restart_file" => restart_file,
"output_dir" => joinpath(output_loc, job_id),
"insolation" => "rcemipii",
"detect_restart_file" => true,
),
),
job_id = job_id * "_different",
)
@test_throws ErrorException CA.get_simulation(config_different)

# end
# end
# end
# end
end
end

0 comments on commit 139029b

Please sign in to comment.