Skip to content

Commit

Permalink
Merge pull request #3299 from CliMA/ck/inference3
Browse files Browse the repository at this point in the history
Convert bools -> types to improve inference
  • Loading branch information
charleskawczynski authored Sep 16, 2024
2 parents b1c50ea + 84aea3f commit 70463c6
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 89 deletions.
27 changes: 15 additions & 12 deletions src/cache/diagnostic_edmf_precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
TD.relative_humidity(thermo_params, ts_prev_level),
FT(0),
tke_prev_level,
p.atmos.edmfx_entr_model,
p.atmos.edmfx_model.entr_model,
)

# We don't have an upper limit to entrainment for the first level
Expand All @@ -422,16 +422,19 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
)

# TODO: use updraft top instead of scale height
@. nh_pressure³ʲ_prev_halflevel = ᶠupdraft_nh_pressure(
params,
p.atmos.edmfx_nh_pressure,
local_geometry_prev_halflevel,
-∇Φ³_prev_level * (ρʲ_prev_level - ρ_prev_level) /
ρʲ_prev_level,
u³ʲ_prev_halflevel,
u³⁰_prev_halflevel,
scale_height,
)
if p.atmos.edmfx_model.nh_pressure isa Val{true}
@. nh_pressure³ʲ_prev_halflevel = ᶠupdraft_nh_pressure(
params,
local_geometry_prev_halflevel,
-∇Φ³_prev_level * (ρʲ_prev_level - ρ_prev_level) /
ρʲ_prev_level,
u³ʲ_prev_halflevel,
u³⁰_prev_halflevel,
scale_height,
)
else
@. nh_pressure³ʲ_prev_halflevel = CT3(0)
end

nh_pressure³ʲ_data_prev_halflevel =
nh_pressure³ʲ_prev_halflevel.components.data.:1
Expand Down Expand Up @@ -558,7 +561,7 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
vert_div_level,
FT(0), # mass flux divergence is not implemented for diagnostic edmf
tke_prev_level,
p.atmos.edmfx_detr_model,
p.atmos.edmfx_model.detr_model,
)

@. detrʲ_prev_level = limit_detrainment(
Expand Down
4 changes: 2 additions & 2 deletions src/cache/prognostic_edmf_precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ NVTX.@annotate function set_prognostic_edmf_precomputed_quantities_closures!(
TD.relative_humidity(thermo_params, ᶜts⁰),
FT(0),
max(ᶜtke⁰, 0),
p.atmos.edmfx_entr_model,
p.atmos.edmfx_model.entr_model,
)
@. ᶜentrʲs.:($$j) = limit_entrainment(
ᶜentrʲs.:($$j),
Expand Down Expand Up @@ -258,7 +258,7 @@ NVTX.@annotate function set_prognostic_edmf_precomputed_quantities_closures!(
ᶜvert_div,
ᶜmassflux_vert_div,
ᶜtke⁰,
p.atmos.edmfx_detr_model,
p.atmos.edmfx_model.detr_model,
)
@. ᶜdetrʲs.:($$j) = limit_detrainment(
ᶜdetrʲs.:($$j),
Expand Down
64 changes: 26 additions & 38 deletions src/prognostic_equations/edmfx_closures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,42 +55,28 @@ end
Inputs (everything defined on cell faces):
- params - set with model parameters
- nh_presssure_flag - bool flag for if we want/don't want to compute
pressure drag
- ᶠlg - local geometry (needed to compute the norm inside a local function)
- ᶠbuoyʲ - covariant3 or contravariant3 updraft buoyancy
- ᶠu3ʲ, ᶠu3⁰ - covariant3 or contravariant3 velocity for updraft and environment.
covariant3 velocity is used in prognostic edmf, and contravariant3
velocity is used in diagnostic edmf.
- updraft top height
"""
function ᶠupdraft_nh_pressure(
params,
nh_pressure_flag,
ᶠlg,
ᶠbuoyʲ,
ᶠu3ʲ,
ᶠu3⁰,
plume_height,
)
if !nh_pressure_flag
return zero(ᶠu3ʲ)
else
turbconv_params = CAP.turbconv_params(params)
# factor multiplier for pressure buoyancy terms (effective buoyancy is (1-α_b))
α_b = CAP.pressure_normalmode_buoy_coeff1(turbconv_params)
# factor multiplier for pressure drag
α_d = CAP.pressure_normalmode_drag_coeff(turbconv_params)
function ᶠupdraft_nh_pressure(params, ᶠlg, ᶠbuoyʲ, ᶠu3ʲ, ᶠu3⁰, plume_height)
turbconv_params = CAP.turbconv_params(params)
# factor multiplier for pressure buoyancy terms (effective buoyancy is (1-α_b))
α_b = CAP.pressure_normalmode_buoy_coeff1(turbconv_params)
# factor multiplier for pressure drag
α_d = CAP.pressure_normalmode_drag_coeff(turbconv_params)

# Independence of aspect ratio hardcoded: α₂_asp_ratio² = FT(0)
# Independence of aspect ratio hardcoded: α₂_asp_ratio² = FT(0)

H_up_min = CAP.min_updraft_top(turbconv_params)
H_up_min = CAP.min_updraft_top(turbconv_params)

# We also used to have advection term here: α_a * w_up * div_w_up
return α_b * ᶠbuoyʲ +
α_d * (ᶠu3ʲ - ᶠu3⁰) * CC.Geometry._norm(ᶠu3ʲ - ᶠu3⁰, ᶠlg) /
max(plume_height, H_up_min)
end
# We also used to have advection term here: α_a * w_up * div_w_up
return α_b * ᶠbuoyʲ +
α_d * (ᶠu3ʲ - ᶠu3⁰) * CC.Geometry._norm(ᶠu3ʲ - ᶠu3⁰, ᶠlg) /
max(plume_height, H_up_min)
end

edmfx_nh_pressure_tendency!(Yₜ, Y, p, t, turbconv_model) = nothing
Expand All @@ -111,17 +97,19 @@ function edmfx_nh_pressure_tendency!(
scale_height = CAP.R_d(params) * CAP.T_surf_ref(params) / CAP.grav(params)

for j in 1:n
@. ᶠnh_pressure₃ʲs.:($$j) = ᶠupdraft_nh_pressure(
params,
p.atmos.edmfx_nh_pressure,
ᶠlg,
ᶠbuoyancy(ᶠinterp(Y.c.ρ), ᶠinterp(ᶜρʲs.:($$j)), ᶠgradᵥ_ᶜΦ),
Y.f.sgsʲs.:($$j).u₃,
ᶠu₃⁰,
scale_height,
)

@. Yₜ.f.sgsʲs.:($$j).u₃ -= ᶠnh_pressure₃ʲs.:($$j)
if p.atmos.edmfx_model.nh_pressure isa Val{true}
@. ᶠnh_pressure₃ʲs.:($$j) = ᶠupdraft_nh_pressure(
params,
ᶠlg,
ᶠbuoyancy(ᶠinterp(Y.c.ρ), ᶠinterp(ᶜρʲs.:($$j)), ᶠgradᵥ_ᶜΦ),
Y.f.sgsʲs.:($$j).u₃,
ᶠu₃⁰,
scale_height,
)
@. Yₜ.f.sgsʲs.:($$j).u₃ -= ᶠnh_pressure₃ʲs.:($$j)
else
@. ᶠnh_pressure₃ʲs.:($$j) = C3(0)
end
end
end

Expand Down Expand Up @@ -299,7 +287,7 @@ function edmfx_filter_tendency!(Yₜ, Y, p, t, turbconv_model::PrognosticEDMFX)
n = n_mass_flux_subdomains(turbconv_model)
(; dt) = p

if p.atmos.edmfx_filter
if p.atmos.edmfx_model.filter isa Val{true}
for j in 1:n
@. Yₜ.f.sgsʲs.:($$j).u₃ -=
C3(min(Y.f.sgsʲs.:($$j).u₃.components.data.:1, 0)) / dt
Expand Down
8 changes: 4 additions & 4 deletions src/prognostic_equations/edmfx_sgs_flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function edmfx_sgs_mass_flux_tendency!(
(; dt) = p
ᶜJ = Fields.local_geometry_field(Y.c).J

if p.atmos.edmfx_sgs_mass_flux
if p.atmos.edmfx_model.sgs_mass_flux isa Val{true}
# energy
ᶠu³_diff = p.scratch.ᶠtemp_CT3
ᶜa_scalar = p.scratch.ᶜtemp_scalar
Expand Down Expand Up @@ -105,7 +105,7 @@ function edmfx_sgs_mass_flux_tendency!(
ᶜJ = Fields.local_geometry_field(Y.c).J
FT = eltype(Y)

if p.atmos.edmfx_sgs_mass_flux
if p.atmos.edmfx_model.sgs_mass_flux isa Val{true}
# energy
ᶠu³_diff = p.scratch.ᶠtemp_CT3
ᶜa_scalar = p.scratch.ᶜtemp_scalar
Expand Down Expand Up @@ -189,7 +189,7 @@ function edmfx_sgs_diffusive_flux_tendency!(
(; ᶜK_u, ᶜK_h, ρatke_flux) = p.precomputed
ᶠgradᵥ = Operators.GradientC2F()

if p.atmos.edmfx_sgs_diffusive_flux
if p.atmos.edmfx_model.sgs_diffusive_flux isa Val{true}
ᶠρaK_h = p.scratch.ᶠtemp_scalar
@. ᶠρaK_h = ᶠinterp(ᶜρa⁰) * ᶠinterp(ᶜK_h)
ᶠρaK_u = p.scratch.ᶠtemp_scalar
Expand Down Expand Up @@ -258,7 +258,7 @@ function edmfx_sgs_diffusive_flux_tendency!(
(; ᶜK_u, ᶜK_h, ρatke_flux) = p.precomputed
ᶠgradᵥ = Operators.GradientC2F()

if p.atmos.edmfx_sgs_diffusive_flux
if p.atmos.edmfx_model.sgs_diffusive_flux isa Val{true}
ᶠρaK_h = p.scratch.ᶠtemp_scalar
@. ᶠρaK_h = ᶠinterp(Y.c.ρ) * ᶠinterp(ᶜK_h)
ᶠρaK_u = p.scratch.ᶠtemp_scalar
Expand Down
31 changes: 10 additions & 21 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,21 @@ function get_atmos(config::AtmosConfig, params)
advection_test = parsed_args["advection_test"]
@assert advection_test in (false, true)

edmfx_entr_model = get_entrainment_model(parsed_args)
edmfx_detr_model = get_detrainment_model(parsed_args)

edmfx_sgs_mass_flux = parsed_args["edmfx_sgs_mass_flux"]
@assert edmfx_sgs_mass_flux in (false, true)

edmfx_sgs_diffusive_flux = parsed_args["edmfx_sgs_diffusive_flux"]
@assert edmfx_sgs_diffusive_flux in (false, true)

edmfx_nh_pressure = parsed_args["edmfx_nh_pressure"]
@assert edmfx_nh_pressure in (false, true)

edmfx_filter = parsed_args["edmfx_filter"]
@assert edmfx_filter in (false, true)

implicit_diffusion = parsed_args["implicit_diffusion"]
@assert implicit_diffusion in (true, false)

implicit_sgs_advection = parsed_args["implicit_sgs_advection"]
@assert implicit_sgs_advection in (true, false)

edmfx_model = EDMFXModel(;
entr_model = get_entrainment_model(parsed_args),
detr_model = get_detrainment_model(parsed_args),
sgs_mass_flux = Val(parsed_args["edmfx_sgs_mass_flux"]),
sgs_diffusive_flux = Val(parsed_args["edmfx_sgs_diffusive_flux"]),
nh_pressure = Val(parsed_args["edmfx_nh_pressure"]),
filter = Val(parsed_args["edmfx_filter"]),
)

model_config = get_model_config(parsed_args)
vert_diff =
get_vertical_diffusion_model(diffuse_momentum, parsed_args, params, FT)
Expand All @@ -65,12 +59,7 @@ function get_atmos(config::AtmosConfig, params)
edmf_coriolis = get_edmf_coriolis(parsed_args, FT),
advection_test,
tendency_model = get_tendency_model(parsed_args),
edmfx_entr_model,
edmfx_detr_model,
edmfx_sgs_mass_flux,
edmfx_sgs_diffusive_flux,
edmfx_nh_pressure,
edmfx_filter,
edmfx_model,
precip_model,
cloud_model,
forcing_type,
Expand Down
32 changes: 20 additions & 12 deletions src/solver/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,24 @@ function Base.summary(io::IO, numerics::AtmosNumerics)
end
end

const ValTF = Union{Val{true}, Val{false}}

Base.@kwdef struct EDMFXModel{
EEM,
EDM,
ESMF <: ValTF,
ESDF <: ValTF,
ENP <: ValTF,
EVR <: ValTF,
}
entr_model::EEM = nothing
detr_model::EDM = nothing
sgs_mass_flux::ESMF = Val(false)
sgs_diffusive_flux::ESDF = Val(false)
nh_pressure::ENP = Val(false)
filter::EVR = Val(false)
end

Base.@kwdef struct AtmosModel{
MC,
MM,
Expand All @@ -356,12 +374,7 @@ Base.@kwdef struct AtmosModel{
EC,
AT,
TM,
EEM,
EDM,
ESMF,
ESDF,
ENP,
EVR,
EDMFX,
TCM,
NOGW,
OGW,
Expand Down Expand Up @@ -390,12 +403,7 @@ Base.@kwdef struct AtmosModel{
edmf_coriolis::EC = nothing
advection_test::AT = nothing
tendency_model::TM = nothing
edmfx_entr_model::EEM = nothing
edmfx_detr_model::EDM = nothing
edmfx_sgs_mass_flux::ESMF = nothing
edmfx_sgs_diffusive_flux::ESDF = nothing
edmfx_nh_pressure::ENP = nothing
edmfx_filter::EVR = nothing
edmfx_model::EDMFX = nothing
turbconv_model::TCM = nothing
non_orographic_gravity_wave::NOGW = nothing
orographic_gravity_wave::OGW = nothing
Expand Down

0 comments on commit 70463c6

Please sign in to comment.