From d688218a0793f36a7e3d175f5c629aed955d269a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E8=90=A7=E6=B6=AF?= Date: Fri, 12 Jul 2024 11:16:43 -0700 Subject: [PATCH 1/2] refactor gravity wave cache and tendency --- .../non_orographic_gravity_wave.jl | 156 ++++++++++++------ .../nogw_test_3d.jl | 3 + .../nogw_test_mima.jl | 4 +- .../nogw_test_single_column.jl | 14 +- 4 files changed, 125 insertions(+), 52 deletions(-) diff --git a/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl b/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl index c71e424ae3..cd82deacc3 100644 --- a/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl +++ b/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl @@ -5,6 +5,7 @@ import ClimaCore.Spaces as Spaces import ClimaCore.Fields as Fields import ClimaCore.Geometry as Geometry +import ClimaCore.Operators as Operators non_orographic_gravity_wave_cache(Y, atmos::AtmosModel) = non_orographic_gravity_wave_cache( @@ -27,12 +28,18 @@ function non_orographic_gravity_wave_cache( nc = Int(floor(FT(2 * cmax / dc + 1))) c = [FT((n - 1) * dc - cmax) for n in 1:nc] - + source_level_z = similar(Fields.level(Y.c.ρ, 1), Tuple{FT, FT}) + ᶜlevel = similar(Y.c.ρ, FT) + for i in 1:Spaces.nlevels(axes(Y.c.ρ)) + fill!(Fields.level(ᶜlevel, i), i) + end + damp_level_z = similar(source_level_z) return (; gw_source_height = source_height, gw_source_ampl = Bt_0 .* ones(FT, axes(Fields.level(Y.c.ρ, 1))), gw_Bw = Bw .* ones(FT, axes(Fields.level(Y.c.ρ, 1))), gw_Bn = Bn .* ones(FT, axes(Fields.level(Y.c.ρ, 1))), + gw_B0 = similar(c), gw_c = c, gw_cw = cw .* ones(FT, axes(Fields.level(Y.c.ρ, 1))), gw_cn = cn .* ones(FT, axes(Fields.level(Y.c.ρ, 1))), @@ -41,6 +48,15 @@ function non_orographic_gravity_wave_cache( gw_nk = Int(nk), ᶜbuoyancy_frequency = similar(Y.c.ρ), ᶜdTdz = similar(Y.c.ρ), + source_level_z, + damp_level_z, + source_level = similar(Fields.level(Y.c.ρ, 1)), + damp_level = similar(Fields.level(Y.c.ρ, 1)), + ᶜlevel, + u_phy = similar(Y.c.ρ), + v_phy = similar(Y.c.ρ), + uforcing = similar(Y.c.ρ), + vforcing = similar(Y.c.ρ), ) end @@ -66,6 +82,13 @@ function non_orographic_gravity_wave_cache( gw_Bw = ones(FT, axes(lat)) .* Bw gw_cn = ones(FT, axes(lat)) .* cn + source_level_z = similar(Fields.level(Y.c.ρ, 1), Tuple{FT, FT}) + ᶜlevel = similar(Y.c.ρ, FT) + for i in 1:Spaces.nlevels(axes(Y.c.ρ)) + fill!(Fields.level(ᶜlevel, i), i) + end + damp_level_z = similar(source_level_z) + # This is GFDL source specs -> a smooth function # source_ampl = @. Bt_0 + # Bt_n * FT(0.5) * (FT(1) + tanh((lat - ϕ0_n) / dϕ_n)) + @@ -94,6 +117,7 @@ function non_orographic_gravity_wave_cache( gw_source_ampl = source_ampl, gw_Bw = gw_Bw, gw_Bn = gw_Bn, + gw_B0 = similar(c), gw_c = c, gw_cw = gw_cw, gw_cn = gw_cn, @@ -102,6 +126,15 @@ function non_orographic_gravity_wave_cache( gw_nk = Int(nk), ᶜbuoyancy_frequency = similar(Y.c.ρ), ᶜdTdz = similar(Y.c.ρ), + source_level_z, + damp_level_z, + source_level = similar(Fields.level(Y.c.ρ, 1)), + damp_level = similar(Fields.level(Y.c.ρ, 1)), + ᶜlevel, + u_phy = similar(Y.c.ρ), + v_phy = similar(Y.c.ρ), + uforcing = similar(Y.c.ρ), + vforcing = similar(Y.c.ρ), ) end @@ -116,12 +149,25 @@ function non_orographic_gravity_wave_tendency!( (; ᶜT,) = p.core (; ᶜts) = p.precomputed (; params) = p - (; ᶜdTdz, ᶜbuoyancy_frequency) = p.non_orographic_gravity_wave + (; + ᶜdTdz, + ᶜbuoyancy_frequency, + source_level, + source_level_z, + damp_level, + damp_level_z, + u_phy, + v_phy, + uforcing, + vforcing, + ᶜlevel, + ) = p.non_orographic_gravity_wave (; model_config) = p.atmos (; gw_source_ampl, gw_Bw, gw_Bn, + gw_B0, gw_c, gw_cw, gw_cn, @@ -145,10 +191,10 @@ function non_orographic_gravity_wave_tendency!( # compute buoyancy frequency @. ᶜT = TD.air_temperature(thermo_params, ᶜts) - parent(ᶜdTdz) .= parent(Geometry.WVector.(ᶜgradᵥ.(ᶠinterp.(ᶜT)))) + ᶜdTdz .= Geometry.WVector.(ᶜgradᵥ.(ᶠinterp.(ᶜT))).components.data.:1 - ᶜbuoyancy_frequency = - @. (grav / ᶜT) * (ᶜdTdz + grav / TD.cp_m(thermo_params, ᶜts)) + @. ᶜbuoyancy_frequency = + (grav / ᶜT) * (ᶜdTdz + grav / TD.cp_m(thermo_params, ᶜts)) ᶜbuoyancy_frequency = @. ifelse( ᶜbuoyancy_frequency < FT(2.5e-5), FT(sqrt(2.5e-5)), @@ -157,56 +203,65 @@ function non_orographic_gravity_wave_tendency!( if model_config isa SingleColumnModel # source level: the index of the level that is closest to the source height - source_level = similar(Fields.level(Y.c.ρ, 1)) - Fields.bycolumn(axes(ᶜρ)) do colidx - parent(source_level[colidx]) .= - argmin(abs.(parent(ᶜz[colidx]) .- gw_source_height))[1] - end - # damp level: for now we only deposit to top level for column setup - damp_level = similar(Fields.level(Y.c.ρ, 1)) - Fields.bycolumn(axes(ᶜρ)) do colidx - parent(damp_level[colidx]) .= length(parent(ᶜz[colidx])) + + Operators.column_mapreduce!( + reduce_fun1, + source_level_z, + ᶜz, + ᶜlevel, + ) do z, level + (abs.(z .- gw_source_height), level) end + source_level = source_level_z.:2 + + Operators.column_mapreduce!(sign, +, damp_level, ᶜz) + elseif model_config isa SphericalModel (; ᶜp) = p.precomputed # source level: the index of the highest level whose pressure is higher than source pressure - source_level = similar(Fields.level(Y.c.ρ, 1)) - Fields.bycolumn(axes(ᶜρ)) do colidx - parent(source_level[colidx]) .= - findlast(parent(ᶜp[colidx]) .> gw_source_pressure)[1] + + Operators.column_mapreduce!( + reduce_fun2, + source_level_z, + ᶜp, + ᶜlevel, + ) do p, level + (p .- gw_source_pressure, level) end + source_level = source_level_z.:2 + + # damp level: the index of the lowest level whose pressure is lower than the damp pressure - damp_level = similar(Fields.level(Y.c.ρ, 1)) - Fields.bycolumn(axes(ᶜρ)) do colidx - if sum(parent(ᶜp[colidx]) .< gw_damp_pressure) == 0 - parent(damp_level[colidx]) .= length(parent(ᶜz[colidx])) - else - parent(damp_level[colidx]) .= - findfirst(parent(ᶜp[colidx]) .< gw_damp_pressure)[1] - end + + Operators.column_mapreduce!( + reduce_fun3, + damp_level_z, + ᶜp, + ᶜlevel, + ) do p, level + (p .- gw_damp_pressure, level) end + damp_level = damp_level_z.:2 + end # prepare physical uv input variables for gravity_wave_forcing() u_phy = Geometry.UVVector.(Y.c.uₕ).components.data.:1 v_phy = Geometry.UVVector.(Y.c.uₕ).components.data.:2 - # a place holder to store physical forcing on uv - uforcing = ones(axes(u_phy)) - vforcing = ones(axes(u_phy)) - # GW parameterization applied bycolume Fields.bycolumn(axes(ᶜρ)) do colidx parent(uforcing[colidx]) .= non_orographic_gravity_wave_forcing( - copy(vec(parent(u_phy[colidx]))), - copy(vec(parent(ᶜbuoyancy_frequency[colidx]))), - copy(vec(parent(ᶜρ[colidx]))), - copy(vec(parent(ᶜz[colidx]))), + vec(parent(u_phy[colidx])), + vec(parent(ᶜbuoyancy_frequency[colidx])), + vec(parent(ᶜρ[colidx])), + vec(parent(ᶜz[colidx])), Int(parent(source_level[colidx])[1]), Int(parent(damp_level[colidx])[1]), parent(gw_source_ampl[colidx])[1], parent(gw_Bw[colidx])[1], parent(gw_Bn[colidx])[1], + gw_B0, parent(gw_cw[colidx])[1], parent(gw_cn[colidx])[1], parent(gw_flag[colidx])[1], @@ -216,15 +271,16 @@ function non_orographic_gravity_wave_tendency!( ) parent(vforcing[colidx]) .= non_orographic_gravity_wave_forcing( - copy(vec(parent(v_phy[colidx]))), - copy(vec(parent(ᶜbuoyancy_frequency[colidx]))), - copy(vec(parent(ᶜρ[colidx]))), - copy(vec(parent(ᶜz[colidx]))), + vec(parent(v_phy[colidx])), + vec(parent(ᶜbuoyancy_frequency[colidx])), + vec(parent(ᶜρ[colidx])), + vec(parent(ᶜz[colidx])), Int(parent(source_level[colidx])[1]), Int(parent(damp_level[colidx])[1]), parent(gw_source_ampl[colidx])[1], parent(gw_Bw)[1], parent(gw_Bn)[1], + gw_B0, parent(gw_cw)[1], parent(gw_cn)[1], parent(gw_flag)[1], @@ -242,15 +298,16 @@ function non_orographic_gravity_wave_tendency!( end function non_orographic_gravity_wave_forcing( - ᶜu, - ᶜbf, - ᶜρ, - ᶜz, + old_ᶜu, + old_ᶜbf, + old_ᶜρ, + old_ᶜz, source_level, damp_level, source_ampl, Bw, Bn, + B0, cw, cn, flag, @@ -258,17 +315,18 @@ function non_orographic_gravity_wave_forcing( c0, nk, ) - FT = eltype(ᶜz) + FT = eltype(old_ᶜz) # add an extra layer above model top so that forcing between the very top # model layer and the upper boundary can be calculated - append!(ᶜu, FT(2) * ᶜu[end] - ᶜu[end - 1]) - append!(ᶜρ, ᶜρ[end] * ᶜρ[end] / ᶜρ[end - 1]) - append!(ᶜbf, ᶜbf[end]) - append!(ᶜz, FT(2) * ᶜz[end] - ᶜz[end - 1]) + ᶜu = vcat(old_ᶜu, FT(2) * old_ᶜu[end] - old_ᶜu[end - 1]) + ᶜρ = vcat(old_ᶜρ, old_ᶜρ[end] * old_ᶜρ[end] / old_ᶜρ[end - 1]) + ᶜbf = vcat(old_ᶜbf, old_ᶜbf[end]) + ᶜz = vcat(old_ᶜz, FT(2) * old_ᶜz[end] - old_ᶜz[end - 1]) # wave spectra and the source amplitude nc = length(c) c_hat0 = c .- ᶜu[source_level] # c0mu0 + Bw_exp = @. exp(-log(2.0) * ((c * flag + c_hat0 * (1 - flag) - c0) / cw)^2) Bn_exp = @. exp(-log(2.0) * ((c * flag + c_hat0 * (1 - flag) - c0) / cn)^2) B0 = @. sign(c_hat0) * (Bw * Bw_exp + Bn * Bn_exp) @@ -388,3 +446,7 @@ end function calc_intermitency(ρ_source_level, source_ampl, nk, Bsum) return (source_ampl / ρ_source_level / nk) / Bsum end + +@inline reduce_fun1(a, b) = ifelse(a[1] < b[1], a, b) +@inline reduce_fun2(a, b) = ifelse(b[1] < 0, a, b) +@inline reduce_fun3(a, b) = ifelse(a[1] > 0, b, a) diff --git a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_3d.jl b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_3d.jl index d55ee0cedb..45cbe75b0c 100644 --- a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_3d.jl +++ b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_3d.jl @@ -122,6 +122,8 @@ center_u_zonalave = mean(center_u, dims = 1)[1, :, :, :] center_bf_zonalave = mean(center_bf, dims = 1)[1, :, :, :] center_ρ_zonalave = mean(center_ρ, dims = 1)[1, :, :, :] +B0 = similar(params.gw_c) + # Jan month = Dates.month.(time) @@ -140,6 +142,7 @@ for j in 1:length(lat) params.gw_source_ampl, params.gw_Bw, params.gw_Bn, + B0, params.gw_cw, params.gw_cn, params.gw_flag, diff --git a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_mima.jl b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_mima.jl index 24c3b75613..13719a2589 100644 --- a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_mima.jl +++ b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_mima.jl @@ -125,7 +125,7 @@ bf = @. ifelse(bf < 2.5e-5, sqrt(2.5e-5), sqrt(abs(bf))) # compute u/v forcings from convective gravity waves params = non_orographic_gravity_wave(lat, FT) - +B0 = similar(params.gw_c) # nogw forcing kmax = length(pfull) - 1 @@ -156,6 +156,7 @@ for i in 1:length(lon) params.gw_source_ampl[j], params.gw_Bw, params.gw_Bn[j], + B0, params.gw_cw[j], params.gw_cn, params.gw_flag[j], @@ -190,6 +191,7 @@ for i in 1:length(lon) params.gw_source_ampl[j], params.gw_Bw, params.gw_Bn[j], + B0, params.gw_cw[j], params.gw_cn, params.gw_flag[j], diff --git a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_single_column.jl b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_single_column.jl index aae371a941..aa98a18a51 100644 --- a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_single_column.jl +++ b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_single_column.jl @@ -129,6 +129,8 @@ ENV["GKSwstype"] = "nul" output_dir = "nonorographic_gravity_wave_test_single_column" mkpath(output_dir) +B0 = similar(params.gw_c) + # Jan Jan_u = mean(center_u_mean[:, month .== 1], dims = 2)[:, 1] Jan_bf = mean(center_bf_mean[:, month .== 1], dims = 2)[:, 1] @@ -137,12 +139,13 @@ Jan_uforcing = CA.non_orographic_gravity_wave_forcing( Jan_u, Jan_bf, Jan_ρ, - copy(center_z), + center_z, source_level, damp_level, params.gw_source_ampl, params.gw_Bw, params.gw_Bn, + B0, params.gw_cw, params.gw_cn, params.gw_flag, @@ -167,12 +170,13 @@ April_uforcing = CA.non_orographic_gravity_wave_forcing( April_u, April_bf, April_ρ, - copy(center_z), + center_z, source_level, damp_level, params.gw_source_ampl, params.gw_Bw, params.gw_Bn, + B0, params.gw_cw, params.gw_cn, params.gw_flag, @@ -197,12 +201,13 @@ July_uforcing = CA.non_orographic_gravity_wave_forcing( July_u, July_bf, July_ρ, - copy(center_z), + center_z, source_level, damp_level, params.gw_source_ampl, params.gw_Bw, params.gw_Bn, + B0, params.gw_cw, params.gw_cn, params.gw_flag, @@ -227,12 +232,13 @@ Oct_uforcing = CA.non_orographic_gravity_wave_forcing( Oct_u, Oct_bf, Oct_ρ, - copy(center_z), + center_z, source_level, damp_level, params.gw_source_ampl, params.gw_Bw, params.gw_Bn, + B0, params.gw_cw, params.gw_cn, params.gw_flag, From e2b328729ef6e97a75e05a4544a8c121f13881dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E8=90=A7=E6=B6=AF?= Date: Tue, 16 Jul 2024 19:08:54 -0700 Subject: [PATCH 2/2] change function names --- .../non_orographic_gravity_wave.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl b/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl index cd82deacc3..5307e1a409 100644 --- a/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl +++ b/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl @@ -205,7 +205,7 @@ function non_orographic_gravity_wave_tendency!( # source level: the index of the level that is closest to the source height Operators.column_mapreduce!( - reduce_fun1, + min_distance_reduce, source_level_z, ᶜz, ᶜlevel, @@ -214,14 +214,14 @@ function non_orographic_gravity_wave_tendency!( end source_level = source_level_z.:2 - Operators.column_mapreduce!(sign, +, damp_level, ᶜz) + fill!(damp_level, Spaces.nlevels(axes(ᶜz))) elseif model_config isa SphericalModel (; ᶜp) = p.precomputed # source level: the index of the highest level whose pressure is higher than source pressure Operators.column_mapreduce!( - reduce_fun2, + positive_selector_reduce, source_level_z, ᶜp, ᶜlevel, @@ -234,7 +234,7 @@ function non_orographic_gravity_wave_tendency!( # damp level: the index of the lowest level whose pressure is lower than the damp pressure Operators.column_mapreduce!( - reduce_fun3, + negative_selector_reduce, damp_level_z, ᶜp, ᶜlevel, @@ -447,6 +447,6 @@ function calc_intermitency(ρ_source_level, source_ampl, nk, Bsum) return (source_ampl / ρ_source_level / nk) / Bsum end -@inline reduce_fun1(a, b) = ifelse(a[1] < b[1], a, b) -@inline reduce_fun2(a, b) = ifelse(b[1] < 0, a, b) -@inline reduce_fun3(a, b) = ifelse(a[1] > 0, b, a) +@inline min_distance_reduce(a, b) = ifelse(a[1] < b[1], a, b) +@inline positive_selector_reduce(a, b) = ifelse(b[1] <= 0, a, b) +@inline negative_selector_reduce(a, b) = ifelse(a[1] >= 0, b, a)