Skip to content

Commit

Permalink
Cache: remove do_dss
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed Oct 25, 2023
1 parent a1ca826 commit 8bc938d
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
20 changes: 14 additions & 6 deletions src/cache/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ using ClimaCore.Utilities: half

import ClimaCore.Fields: ColumnField

function do_dss(space::Spaces.AbstractSpace)
quadrature_style = Spaces.horizontal_space(space).quadrature_style
return quadrature_style isa Spaces.Quadratures.GLL
end

do_dss(Yc::Fields.Field) = do_dss(axes(Yc))
do_dss(Y::Fields.FieldVector) = do_dss(Y.c)

# Functions on which the model depends:
# CAP.R_d(params) # dry specific gas constant
# CAP.kappa_d(params) # dry adiabatic exponent
Expand Down Expand Up @@ -60,11 +68,12 @@ function default_cache(
end
ᶜf = @. CT3(Geometry.WVector(ᶜf))

quadrature_style = Spaces.horizontal_space(axes(Y.c)).quadrature_style
do_dss = quadrature_style isa Spaces.Quadratures.GLL
ghost_buffer =
!do_dss ? (;) :
(; c = Spaces.create_dss_buffer(Y.c), f = Spaces.create_dss_buffer(Y.f))
do_dss(Y) ?
(;
c = Spaces.create_dss_buffer(Y.c),
f = Spaces.create_dss_buffer(Y.f),
) : (;)

limiter =
isnothing(numerics.limiter) ? nothing :
Expand All @@ -87,14 +96,13 @@ function default_cache(
ᶜf,
∂ᶜK_∂ᶠu₃ = similar(Y.c, BidiagonalMatrixRow{Adjoint{FT, CT3{FT}}}),
params,
do_dss,
ghost_buffer,
net_energy_flux_toa,
net_energy_flux_sfc,
env_thermo_quad = SGSQuadrature(FT),
precomputed_quantities(Y, atmos)...,
temporary_quantities(atmos, spaces.center_space, spaces.face_space)...,
hyperdiffusion_cache(Y, atmos, do_dss)...,
hyperdiffusion_cache(Y, atmos)...,
)
set_precomputed_quantities!(Y, default_cache, FT(0))
default_cache.is_init[] = false
Expand Down
2 changes: 1 addition & 1 deletion src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ NVTX.@annotate function compute_diagnostics(integrator)

if eltype(Fields.coordinate_field(axes(Y.c))) <: Geometry.Abstract3DPoint
ᶜvort = @. Geometry.WVector(curlₕ(Y.c.uₕ))
if p.do_dss
if do_dss(Y)
Spaces.weighted_dss!(ᶜvort)
end
dycore_diagnostic = (; dycore_diagnostic..., vorticity = ᶜvort)
Expand Down
2 changes: 1 addition & 1 deletion src/prognostic_equations/dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using ClimaCore.Utilities: half
import ClimaCore.Fields: ColumnField

NVTX.@annotate function dss!(Y, p, t)
if p.do_dss
if do_dss(Y)
Spaces.weighted_dss_start2!(Y.c, p.ghost_buffer.c)
Spaces.weighted_dss_start2!(Y.f, p.ghost_buffer.f)
Spaces.weighted_dss_internal2!(Y.c, p.ghost_buffer.c)
Expand Down
22 changes: 11 additions & 11 deletions src/prognostic_equations/hyperdiffusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import ClimaCore.Geometry as Geometry
import ClimaCore.Fields as Fields
import ClimaCore.Spaces as Spaces

function hyperdiffusion_cache(Y, atmos, do_dss)
function hyperdiffusion_cache(Y, atmos)
isnothing(atmos.hyperdiff) && return (;)
FT = eltype(Y)
n = n_mass_flux_subdomains(atmos.turbconv_model)
Expand Down Expand Up @@ -35,7 +35,7 @@ function hyperdiffusion_cache(Y, atmos, do_dss)
atmos.turbconv_model isa DiagnosticEDMFX ?
(; ᶜ∇²tke⁰ = similar(Y.c, FT)) : (;)
quantities = (; gs_quantities..., sgs_quantities...)
if do_dss
if do_dss(Y)
quantities = (;
quantities...,
hyperdiffusion_ghost_buffer = map(
Expand All @@ -56,14 +56,14 @@ NVTX.@annotate function hyperdiffusion_tendency!(Yₜ, Y, p, t)
diffuse_tke = use_prognostic_tke(turbconv_model)
ᶜJ = Fields.local_geometry_field(Y.c).J
point_type = eltype(Fields.coordinate_field(Y.c))
(; do_dss, ᶜp, ᶜspecific, ᶜ∇²u, ᶜ∇²specific_energy) = p
(; ᶜp, ᶜspecific, ᶜ∇²u, ᶜ∇²specific_energy) = p
if turbconv_model isa PrognosticEDMFX
(; ᶜρa⁰, ᶜρʲs, ᶜ∇²tke⁰, ᶜtke⁰, ᶜ∇²uₕʲs, ᶜ∇²uᵥʲs, ᶜ∇²uʲs, ᶜ∇²h_totʲs) = p
end
if turbconv_model isa DiagnosticEDMFX
(; ᶜtke⁰, ᶜ∇²tke⁰) = p
end
if do_dss
if do_dss(Y)
buffer = p.hyperdiffusion_ghost_buffer
end

Expand All @@ -89,24 +89,24 @@ NVTX.@annotate function hyperdiffusion_tendency!(Yₜ, Y, p, t)
end
end

if do_dss
if do_dss(Y)
NVTX.@range "dss_hyperdiffusion_tendency" color = colorant"green" begin
for dss_op! in (
Spaces.weighted_dss_start!,
Spaces.weighted_dss_internal!,
Spaces.weighted_dss_ghost!,
)
# DSS on Grid scale quantities
# Need to split the DSS computation here, because our DSS
# Need to split the DSS computation here, because our DSS
# operations do not accept Covariant123Vector types
dss_op!(ᶜ∇²u, buffer.ᶜ∇²u)
dss_op!(ᶜ∇²specific_energy, buffer.ᶜ∇²specific_energy)
if diffuse_tke
dss_op!(ᶜ∇²tke⁰, buffer.ᶜ∇²tke⁰)
end
if turbconv_model isa PrognosticEDMFX
# Need to split the DSS computation here, because our DSS
# operations do not accept Covariant123Vector types
# Need to split the DSS computation here, because our DSS
# operations do not accept Covariant123Vector types
for j in 1:n
@. ᶜ∇²uₕʲs.:($$j) = C12(ᶜ∇²uʲs.:($$j))
@. ᶜ∇²uᵥʲs.:($$j) = C3(ᶜ∇²uʲs.:($$j))
Expand Down Expand Up @@ -162,11 +162,11 @@ NVTX.@annotate function tracer_hyperdiffusion_tendency!(Yₜ, Y, p, t)
(; κ₄) = hyperdiff
n = n_mass_flux_subdomains(turbconv_model)

(; do_dss, ᶜspecific, ᶜ∇²specific_tracers) = p
(; ᶜspecific, ᶜ∇²specific_tracers) = p
if turbconv_model isa PrognosticEDMFX
(; ᶜ∇²q_totʲs) = p
end
if do_dss
if do_dss(Y)
buffer = p.hyperdiffusion_ghost_buffer
end

Expand All @@ -181,7 +181,7 @@ NVTX.@annotate function tracer_hyperdiffusion_tendency!(Yₜ, Y, p, t)
end
end

if do_dss
if do_dss(Y)
NVTX.@range "dss_hyperdiffusion_tendency" color = colorant"green" begin
for dss_op! in (
Spaces.weighted_dss_start!,
Expand Down

0 comments on commit 8bc938d

Please sign in to comment.