From e04b890c711758ca4c4b134cd7167f34ccb21983 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Fri, 17 May 2024 16:48:02 -0700 Subject: [PATCH] Add RMSE leaderboard --- Reviewer note: the leaderboard figure is written under dir_paths.artifacts (like the bias_*.png plots) instead of the working directory, and it is skipped when output_dates is empty, because plot_leaderboard reads ANN.attributes while compute_biases returns plain 0.0 for empty dates. experiments/ClimaEarth/run_amip.jl | 44 ++++++++++++--- .../user_io/leaderboard/compare_with_obs.jl | 56 +++++++++++++++++-- .../ClimaEarth/user_io/leaderboard/utils.jl | 1 + 3 files changed, 90 insertions(+), 11 deletions(-) diff --git a/experiments/ClimaEarth/run_amip.jl b/experiments/ClimaEarth/run_amip.jl index ee13a5f6f..143ac9976 100644 --- a/experiments/ClimaEarth/run_amip.jl +++ b/experiments/ClimaEarth/run_amip.jl @@ -923,19 +923,49 @@ if ClimaComms.iamroot(comms_ctx) include("user_io/leaderboard.jl") compare_vars = ["pr"] - function plot_biases(dates, output_name) + function compute_biases(dates) + if isempty(dates) + return map(x -> 0.0, compare_vars) + else + return Leaderboard.compute_biases(atmos_sim.integrator.p.output_dir, compare_vars, dates) + end + end + + function plot_biases(dates, biases, output_name) + isempty(dates) && return nothing + output_path = joinpath(dir_paths.artifacts, "bias_$(output_name).png") - Leaderboard.plot_biases(atmos_sim.integrator.p.output_dir, compare_vars, dates; output_path) + Leaderboard.plot_biases(biases; output_path) end - plot_biases(output_dates, "total") + + ann_biases = compute_biases(output_dates) + plot_biases(output_dates, ann_biases, "total") ## collect all days between cs.dates.date0 and cs.dates.date MAM, JJA, SON, DJF = Leaderboard.split_by_season(output_dates) - !isempty(MAM) && plot_biases(MAM, "MAM") - !isempty(JJA) && plot_biases(JJA, "JJA") - !isempty(SON) && plot_biases(SON, "SON") - !isempty(DJF) && plot_biases(DJF, "DJF") + MAM_biases = compute_biases(MAM) + plot_biases(MAM, MAM_biases, "MAM") + JJA_biases = compute_biases(JJA) + plot_biases(JJA, JJA_biases, "JJA") + SON_biases = compute_biases(SON) + plot_biases(SON, SON_biases, "SON") + DJF_biases = compute_biases(DJF) + plot_biases(DJF, DJF_biases, "DJF") + + rmses = map( 
+ (index) -> Leaderboard.RMSEs(; + model_name = "CliMA", + ANN = ann_biases[index], + DJF = DJF_biases[index], + JJA = JJA_biases[index], + MAM = MAM_biases[index], + SON = SON_biases[index], + ), + eachindex(compare_vars), + ) + + isempty(output_dates) || Leaderboard.plot_leaderboard(rmses; output_path = joinpath(dir_paths.artifacts, "bias_leaderboard.png")) end end diff --git a/experiments/ClimaEarth/user_io/leaderboard/compare_with_obs.jl b/experiments/ClimaEarth/user_io/leaderboard/compare_with_obs.jl index 2cfca1aad..19e6f9629 100644 --- a/experiments/ClimaEarth/user_io/leaderboard/compare_with_obs.jl +++ b/experiments/ClimaEarth/user_io/leaderboard/compare_with_obs.jl @@ -1,16 +1,36 @@ const OBS_DS = Dict() const SIM_DS_KWARGS = Dict() +const OTHER_MODELS_RMSEs = Dict() function preprocess_pr_fn(data) # -1 kg/m/s2 -> 1 mm/day return data .* Float32(-86400) end +Base.@kwdef struct RMSEs + model_name::String + ANN::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0 + DJF::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0 + JJA::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0 + MAM::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0 + SON::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0 +end + +function Base.values(r::RMSEs) + val_or_rmse(v::Real) = v + val_or_rmse(v::ClimaAnalysis.OutputVar) = v.attributes["rmse"] + + return val_or_rmse.([r.ANN, r.DJF, r.JJA, r.MAM, r.SON]) +end + OBS_DS["pr"] = ObsDataSource(; path = joinpath(pr_obs_data_path(), "gpcp.precip.mon.mean.197901-202305.nc"), var_name = "precip") SIM_DS_KWARGS["pr"] = (; preprocess_data_fn = preprocess_pr_fn, new_units = "mm / day") +# TODO: These numbers are eyeballed and should not be relied on. 
Replace them with real values from the various models +OTHER_MODELS_RMSEs["pr"] = [RMSEs(; model_name = "AM4.0", ANN = 0.5, DJF = 1.0, JJA = 1.5, MAM = 0.5, SON = 1.0)] + # OBS_DS["rsut"] = ObsDataSource(; # path = "OBS/CERES_EBAF-TOA_Ed4.2_Subset_200003-202303.g025.nc", # var_name = "toa_sw_all_mon", @@ -27,13 +47,41 @@ function bias(output_dir::AbstractString, short_name::AbstractString, target_dat return bias(obs, sim, target_dates) end -function plot_biases(output_dir, short_names, target_dates::AbstractArray{<:Dates.DateTime}; output_path) - fig = CairoMakie.Figure(; size = (600, 300 * length(short_names))) +function compute_biases(output_dir, short_names, target_dates::AbstractArray{<:Dates.DateTime}) + return map(name -> bias(output_dir, name, target_dates), short_names) +end + +function plot_biases(biases; output_path) + fig = CairoMakie.Figure(; size = (600, 300 * length(biases))) loc = 1 - for short_name in short_names - bias_var = bias(output_dir, short_name, target_dates) + for bias_var in biases ClimaAnalysis.Visualize.heatmap2D_on_globe!(fig, bias_var; p_loc = (1, loc)) loc = loc + 1 end CairoMakie.save(output_path, fig) end + +function plot_leaderboard(rmses; output_path) + fig = CairoMakie.Figure(; size = (600, 300 * length(rmses))) + loc = 1 + + for rmse in rmses + short_name = rmse.ANN.attributes["var_short_name"] + units = rmse.ANN.attributes["units"] + ax = CairoMakie.Axis( + fig[1, loc], + ylabel = "$short_name [$units]", + xticks = (1:5, ["Ann", "DJF", "JJA", "MAM", "SON"]), + title = "Global RMSE", + ) + CairoMakie.scatter!(ax, 1:5, values(rmse), label = rmse.model_name) + for other_model_rmse in OTHER_MODELS_RMSEs[short_name] + CairoMakie.scatter!(ax, 1:5, values(other_model_rmse), label = other_model_rmse.model_name) + end + # Add a fake extra point to center the legend a little better + CairoMakie.scatter!(ax, [6], [0.1], markersize = 0.01) + CairoMakie.axislegend() + loc = loc + 1 + end + CairoMakie.save(output_path, fig) +end diff --git 
a/experiments/ClimaEarth/user_io/leaderboard/utils.jl b/experiments/ClimaEarth/user_io/leaderboard/utils.jl index 50609703b..85d21f74a 100644 --- a/experiments/ClimaEarth/user_io/leaderboard/utils.jl +++ b/experiments/ClimaEarth/user_io/leaderboard/utils.jl @@ -152,6 +152,7 @@ function bias(obs_ds::ObsDataSource, sim_ds::SimDataSource, target_dates::Abstra bias_attribs = Dict{String, Any}( "short_name" => "sim-obs_$short_name", + "var_short_name" => "$short_name", "long_name" => "SIM - OBS mean $short_name\n(RMSE: $rmse $units, Global bias: $global_bias $units)", "rmse" => rmse, "bias" => global_bias,