Skip to content

Commit

Permalink
Merge #2259
Browse files Browse the repository at this point in the history
2259: Clean up: remove underused quantities r=Sbozzolo a=Sbozzolo

Peel off #2244.

Co-authored-by: Gabriele Bozzola <[email protected]>
  • Loading branch information
bors[bot] and Sbozzolo authored Oct 23, 2023
2 parents 6994c4b + 93aa09a commit 28e9b42
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 29 deletions.
3 changes: 2 additions & 1 deletion perf/benchmark_dump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ for h_elem in 8:8:40

@info "Running benchmark_step for h_elem=$h_elem"
n_steps = 10
device = ClimaComms.device(integrator.p.comms_ctx)
comms_ctx = ClimaComms.context(integrator.u.c)
device = ClimaComms.device(comms_ctx)
if device isa ClimaComms.CUDADevice
e = CUDA.@elapsed begin
s = CA.@timed_str begin
Expand Down
3 changes: 2 additions & 1 deletion perf/benchmark_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ CA.benchmark_step!(integrator, Y₀); # compile first

@info "Running benchmark_step!..."
n_steps = 10
device = ClimaComms.device(integrator.p.comms_ctx)
comms_ctx = ClimaComms.context(integrator.u.c)
device = ClimaComms.device(comms_ctx)
if device isa ClimaComms.CUDADevice
e = CUDA.@elapsed begin
s = CA.@timed_str begin
Expand Down
4 changes: 0 additions & 4 deletions src/cache/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ function default_cache(
numerics
(; apply_limiter) = numerics
ᶜcoord = Fields.local_geometry_field(Y.c).coordinates
ᶠcoord = Fields.local_geometry_field(Y.f).coordinates
R_d = FT(CAP.R_d(params))
MSLP = FT(CAP.MSLP(params))
grav = FT(CAP.grav(params))
Expand Down Expand Up @@ -82,11 +81,8 @@ function default_cache(
is_init = Ref(true),
simulation,
atmos,
comms_ctx = ClimaComms.context(axes(Y.c)),
sfc_setup = surface_setup(params),
test,
moisture_model = atmos.moisture_model,
model_config = atmos.model_config,
limiter,
ᶜΦ,
ᶠgradᵥ_ᶜΦ = ᶠgradᵥ.(ᶜΦ),
Expand Down
6 changes: 4 additions & 2 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,8 @@ function save_to_disk_func(integrator)
sec = floor(Int, t % (60 * 60 * 24))
@info "Saving diagnostics to HDF5 file on day $day second $sec"
output_file = joinpath(output_dir, "day$day.$sec.hdf5")
hdfwriter = InputOutput.HDF5Writer(output_file, p.comms_ctx)
comms_ctx = ClimaComms.context(integrator.u.c)
hdfwriter = InputOutput.HDF5Writer(output_file, comms_ctx)
InputOutput.HDF5.write_attribute(hdfwriter.file, "time", t) # TODO: a better way to write metadata
InputOutput.write!(hdfwriter, Y, "Y")
FT = Spaces.undertype(axes(Y.c))
Expand All @@ -509,7 +510,8 @@ function save_restart_func(integrator)
@info "Saving restart file to HDF5 file on day $day second $sec"
mkpath(joinpath(output_dir, "restart"))
output_file = joinpath(output_dir, "restart", "day$day.$sec.hdf5")
hdfwriter = InputOutput.HDF5Writer(output_file, integrator.p.comms_ctx)
comms_ctx = ClimaComms.context(integrator.u.c)
hdfwriter = InputOutput.HDF5Writer(output_file, comms_ctx)
InputOutput.HDF5.write_attribute(hdfwriter.file, "time", t) # TODO: a better way to write metadata
InputOutput.write!(hdfwriter, Y, "Y")
Base.close(hdfwriter)
Expand Down
4 changes: 3 additions & 1 deletion src/diagnostics/writers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This file defines function-generating functions for output_writers for diagnostics. The
# writers come with opinionated defaults.

import ClimaComms
import ClimaCore.Remapping: interpolate_array
import NCDatasets

Expand Down Expand Up @@ -45,7 +46,8 @@ function HDF5Writer()
"$(diagnostic.output_short_name)_$(time).h5",
)

hdfwriter = InputOutput.HDF5Writer(output_path, integrator.p.comms_ctx)
comms_ctx = ClimaComms.context(integrator.u.c)
hdfwriter = InputOutput.HDF5Writer(output_path, comms_ctx)
InputOutput.write!(hdfwriter, value, "$(diagnostic.output_short_name)")
attributes = Dict(
"time" => time,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ function non_orographic_gravity_wave_tendency!(
::NonOrographyGravityWave,
)
#unpack
(; ᶜts, ᶜT, ᶜdTdz, ᶜbuoyancy_frequency, params, model_config) = p
(; ᶜts, ᶜT, ᶜdTdz, ᶜbuoyancy_frequency, params) = p
(; model_config) = p.atmos
(;
gw_source_ampl,
gw_Bw,
Expand Down
14 changes: 2 additions & 12 deletions src/prognostic_equations/cloud_fraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,10 @@ function compute_cloud_fraction(
θl′θl′,
θl′qt′,
)
return quad_loop(
env_thermo_quad,
env_thermo_quad.quadrature_type,
vars,
thermo_params,
)::FT
return quad_loop(env_thermo_quad, vars, thermo_params)::FT
end

function quad_loop(
env_thermo_quad::SGSQuadrature,
quadrature_type::GaussianQuad,
vars,
thermo_params,
)
function quad_loop(env_thermo_quad::SGSQuadrature, vars, thermo_params)

# qt - total water specific humidity
# θl - liquid ice potential temperature
Expand Down
2 changes: 1 addition & 1 deletion src/prognostic_equations/forcing/subsidence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function subsidence_cache(Y, subsidence::Subsidence)
end

function subsidence_tendency!(Yₜ, Y, p, t, colidx, ::Subsidence)
(; moisture_model) = p
moisture_model = p.atmos.moisture_model
subsidence_profile = p.subsidence.prof
ᶜ∇MSE_gm = p.ᶜ∇MSE_gm[colidx]
ᶜsubsidence = p.ᶜsubsidence[colidx]
Expand Down
17 changes: 11 additions & 6 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ thermo_state_type(::NonEquilMoistModel, ::Type{FT}) where {FT} =
TD.PhaseNonEquil{FT}


function get_callbacks(parsed_args, simulation, atmos, params)
function get_callbacks(parsed_args, simulation, atmos, params, comms_ctx)
FT = eltype(params)
(; dt) = simulation

Expand All @@ -488,7 +488,7 @@ function get_callbacks(parsed_args, simulation, atmos, params)
(callbacks..., call_every_dt(save_restart_func, dt_save_restart))
end

if is_distributed(simulation.comms_ctx)
if is_distributed(comms_ctx)
callbacks = (
callbacks...,
call_every_n_steps(
Expand Down Expand Up @@ -566,7 +566,7 @@ function get_cache(
)
end

function get_simulation(config::AtmosConfig, comms_ctx)
function get_simulation(config::AtmosConfig)
(; parsed_args) = config
FT = eltype(config)

Expand All @@ -581,7 +581,6 @@ function get_simulation(config::AtmosConfig, comms_ctx)
mkpath(output_dir)

sim = (;
comms_ctx,
is_debugging_tc = parsed_args["debugging_tc"],
output_dir,
restart = haskey(ENV, "RESTART_FILE"),
Expand Down Expand Up @@ -783,7 +782,7 @@ function get_integrator(config::AtmosConfig)

atmos = get_atmos(config, params)
numerics = get_numerics(config.parsed_args)
simulation = get_simulation(config, config.comms_ctx)
simulation = get_simulation(config)
if config.parsed_args["log_params"]
filepath = joinpath(simulation.output_dir, "$(job_id)_parameters.toml")
CP.log_parameter_information(config.toml_dict, filepath)
Expand Down Expand Up @@ -834,7 +833,13 @@ function get_integrator(config::AtmosConfig)
@info "ode_configuration: $s"

s = @timed_str begin
callback = get_callbacks(config.parsed_args, simulation, atmos, params)
callback = get_callbacks(
config.parsed_args,
simulation,
atmos,
params,
config.comms_ctx,
)
end
@info "get_callbacks: $s"

Expand Down

0 comments on commit 28e9b42

Please sign in to comment.