From 6d30ca669e73270df55d7cbd1f6fa61bb82fe6bd Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Tue, 3 Dec 2024 11:20:02 -0800 Subject: [PATCH] Remove WTE from cache This commit removes the `WallTimeEstimate` from the cache and moves it to an isolated place. In the process, I refactored the struct to split reporting with updating, so that reporting can be done with any frequency/schedule desired using the same Schedule infrastructure used by other functions/diagnostics as well. --- Project.toml | 2 +- src/cache/cache.jl | 12 +--- src/callbacks/callback_helpers.jl | 44 +++++++++++++++ src/callbacks/callbacks.jl | 93 +------------------------------ src/callbacks/get_callbacks.jl | 19 ++++--- test/coupler_compatibility.jl | 2 - 6 files changed, 58 insertions(+), 114 deletions(-) diff --git a/Project.toml b/Project.toml index 15ecbe00d66..9056f9907ad 100644 --- a/Project.toml +++ b/Project.toml @@ -45,7 +45,7 @@ ClimaCore = "0.14.12" ClimaDiagnostics = "0.2.4" ClimaParams = "0.10.12" ClimaTimeSteppers = "0.7.33" -ClimaUtilities = "0.1.14" +ClimaUtilities = "0.1.20" CloudMicrophysics = "0.22.3" Dates = "1" DiffEqBase = "6.145" diff --git a/src/cache/cache.jl b/src/cache/cache.jl index a8925fb7beb..7b6d7c88d59 100644 --- a/src/cache/cache.jl +++ b/src/cache/cache.jl @@ -1,7 +1,5 @@ struct AtmosCache{ FT <: AbstractFloat, - FTE, - WTE, SD, AM, NUM, @@ -30,12 +28,6 @@ struct AtmosCache{ """Timestep of the simulation (in seconds). This is also used by callbacks and tendencies""" dt::FT - """End time of the simulation (in seconds). 
This used by callbacks""" - t_end::FTE - - """Walltime estimate""" - walltime_estimate::WTE - """Start date (used for insolation and for data files).""" start_date::SD @@ -106,7 +98,7 @@ end # The model also depends on f_plane_coriolis_frequency(params) # This is a constant Coriolis frequency that is only used if space is flat function build_cache(Y, atmos, params, surface_setup, sim_info, aerosol_names) - (; dt, t_end, start_date, output_dir) = sim_info + (; dt, start_date, output_dir) = sim_info FT = eltype(params) ᶜcoord = Fields.local_geometry_field(Y.c).coordinates @@ -188,8 +180,6 @@ function build_cache(Y, atmos, params, surface_setup, sim_info, aerosol_names) args = ( dt, - t_end, - WallTimeEstimate(), start_date, atmos, numerics, diff --git a/src/callbacks/callback_helpers.jl b/src/callbacks/callback_helpers.jl index 74bd40f9020..e348761b938 100644 --- a/src/callbacks/callback_helpers.jl +++ b/src/callbacks/callback_helpers.jl @@ -1,4 +1,7 @@ import SciMLBase + +import ClimaDiagnostics.Schedules: AbstractSchedule + ##### ##### Callback helpers ##### @@ -109,3 +112,44 @@ end n_steps_per_cycle_per_cb_diagnostic(cbs) = [callback_frequency(cb).n for cb in cbs if callback_frequency(cb).n > 0] + +import ClimaDiagnostics.Schedules: AbstractSchedule + +import Dates + +""" + CappedGeometricSeriesSchedule(max_steps) + +True every 2^N iterations or every `max_steps`. + +This is useful to have an exponential ramp up of something that saturates to a constant +frequency. 
(For instance, reporting something more frequently at the beginning of the +simulation, and less frequently later) +""" +struct CappedGeometricSeriesSchedule <: AbstractSchedule + """CappedGeometricSeriesSchedule(integrator) is true every 2^N iterations or every max_steps""" + max_steps::Int + """Last step that this returned true""" + step_last::Base.RefValue{Int} + + function CappedGeometricSeriesSchedule(max_steps; step_last = Ref(0)) + return new(max_steps, step_last) + end +end + +""" + CappedGeometricSeriesSchedule(integrator) + +Returns true when `integrator.step` is a power of 2, or when `integrator.step` exceeds +`step_last + max_steps`. `step_last` is the last step at which this schedule returned true +and `max_steps` is the interval cap stored in the schedule. +""" +function (schedule::CappedGeometricSeriesSchedule)(integrator)::Bool + if isinteger(log2(integrator.step)) || + integrator.step > schedule.step_last[] + schedule.max_steps + schedule.step_last[] = integrator.step + return true + else + return false + end +end diff --git a/src/callbacks/callbacks.jl b/src/callbacks/callbacks.jl index ff9c17a438d..a6704c9fe7e 100644 --- a/src/callbacks/callbacks.jl +++ b/src/callbacks/callbacks.jl @@ -14,6 +14,7 @@ using Insolation: instantaneous_zenith_angle import ClimaCore.Fields: ColumnField import ClimaUtilities.TimeVaryingInputs: evaluate! 
+import ClimaUtilities.OnlineLogging: WallTimeInfo, report_walltime include("callback_helpers.jl") @@ -363,98 +364,6 @@ NVTX.@annotate function save_state_to_disk_func(integrator, output_dir) return nothing end -Base.@kwdef mutable struct WallTimeEstimate - """Number of calls to the callback""" - n_calls::Int = 0 - """Int indicating next time the callback will print to the log""" - n_next::Int = 1 - """Wall time of previous call to update `WallTimeEstimate`""" - t_wall_last::Float64 = -1 - """Sum of elapsed walltime over calls to `step!`""" - ∑Δt_wall::Float64 = 0 - """Fixed increment to increase n_next by after 5% completion""" - n_fixed_increment::Float64 = -1 -end -import Dates -function print_walltime_estimate(integrator) - (; walltime_estimate, dt, t_end) = integrator.p - t_start = integrator.sol.prob.tspan[1] - wte = walltime_estimate - - # Notes on `ready_to_report` - # - The very first call (when `n_calls == 0`), there's no elapsed - # times to report (and this is called during initialization, - # before `step!` has been called). - # - The second call (`n_calls == 1`) is after `step!` is called - # for the first time, but we don't want to report this since it - # includes compilation time. - # - Calls after that (`n_calls > 1`) exclude compilation and provide - # the best wall time estimates - - ready_to_report = wte.n_calls > 1 - if ready_to_report - # We need to account for skipping cost of `Δt_wall` when `n_calls == 1`: - factor = wte.n_calls == 2 ? 
2 : 1 - Δt_wall = factor * (time() - wte.t_wall_last) - else - wte.n_calls == 1 && @info "Progress: Completed first step" - Δt_wall = Float64(0) - wte.n_next = wte.n_calls + 1 - end - wte.∑Δt_wall += Δt_wall - wte.t_wall_last = time() - - if wte.n_calls == wte.n_next && ready_to_report - t = integrator.t - n_steps_total = ceil(Int, (t_end - t_start) / dt) - n_steps = ceil(Int, (t - t_start) / dt) - wall_time_ave_per_step = wte.∑Δt_wall / n_steps - wall_time_ave_per_step_str = time_and_units_str(wall_time_ave_per_step) - percent_complete = round((t - t_start) / t_end * 100; digits = 1) - n_steps_remaining = n_steps_total - n_steps - wall_time_remaining = wall_time_ave_per_step * n_steps_remaining - wall_time_remaining_str = time_and_units_str(wall_time_remaining) - wall_time_total = - time_and_units_str(wall_time_ave_per_step * n_steps_total) - wall_time_spent = time_and_units_str(wte.∑Δt_wall) - simulation_time = time_and_units_str(Float64(t)) - es = EfficiencyStats((t_start, t), wte.∑Δt_wall) - _sypd = simulated_years_per_day(es) - _sypd_str = string(round(_sypd; digits = 3)) - sypd = _sypd_str * if _sypd < 0.01 - sdpd = round(_sypd * 365, digits = 3) - " (sdpd = $sdpd)" - else - "" - end - estimated_finish_date = - Dates.now() + compound_period(wall_time_remaining, Dates.Second) - @info "Progress" simulation_time = simulation_time n_steps_completed = - n_steps wall_time_per_step = wall_time_ave_per_step_str wall_time_total = - wall_time_total wall_time_remaining = wall_time_remaining_str wall_time_spent = - wall_time_spent percent_complete = "$percent_complete%" sypd = sypd date_now = - Dates.now() estimated_finish_date = estimated_finish_date - - # the first fixed increment is equivalent to - # doubling (which puts us at 10%), so we check - # if we're below 5%. 
- if percent_complete < 5 - # doubling factor (to reduce log noise) - wte.n_next *= 2 - else - if wte.n_fixed_increment == -1 - wte.n_fixed_increment = wte.n_next - end - # increase by fixed increment after 10% - # completion to maintain logs after 50%. - wte.n_next += wte.n_fixed_increment - end - end - wte.n_calls += 1 - - return nothing -end - function gc_func(integrator) num_pre = Base.gc_num() alloc_since_last = (num_pre.allocd + num_pre.deferred_alloc) / 2^20 diff --git a/src/callbacks/get_callbacks.jl b/src/callbacks/get_callbacks.jl index fb8f6f36a7d..ebf1a4dfebf 100644 --- a/src/callbacks/get_callbacks.jl +++ b/src/callbacks/get_callbacks.jl @@ -218,14 +218,17 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start) callbacks = () if parsed_args["log_progress"] - @info "Progress logging enabled." - callbacks = ( - callbacks..., - call_every_n_steps( - (integrator) -> print_walltime_estimate(integrator); - skip_first = true, - ), - ) + @info "Progress logging enabled" + walltime_info = WallTimeInfo() + tot_steps = ceil(Int, (sim_info.t_end - t_start) / dt) + five_percent_steps = ceil(Int, 0.05 * tot_steps) + cond = let schedule = CappedGeometricSeriesSchedule(five_percent_steps) + (u, t, integrator) -> schedule(integrator) + end + affect! = let wt = walltime_info + (integrator) -> report_walltime(wt, integrator) + end + callbacks = (callbacks..., SciMLBase.DiscreteCallback(cond, affect!)) end check_nan_every = parsed_args["check_nan_every"] if check_nan_every > 0 diff --git a/test/coupler_compatibility.jl b/test/coupler_compatibility.jl index e7b063d4b95..c5b5c9f8d3a 100644 --- a/test/coupler_compatibility.jl +++ b/test/coupler_compatibility.jl @@ -66,8 +66,6 @@ const T2 = 290 @. sfc_setup = (surface_state,) p_overwritten = CA.AtmosCache( p.dt, - simulation.t_end, - CA.WallTimeEstimate(), p.start_date, p.atmos, p.numerics,