Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RMSE leaderboard #799

Merged
merged 1 commit into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions experiments/ClimaEarth/run_amip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,19 +923,49 @@ if ClimaComms.iamroot(comms_ctx)

include("user_io/leaderboard.jl")
compare_vars = ["pr"]
function plot_biases(dates, output_name)
function compute_biases(dates)
if isempty(dates)
return map(x -> 0.0, compare_vars)
else
return Leaderboard.compute_biases(atmos_sim.integrator.p.output_dir, compare_vars, dates)
end
end

function plot_biases(dates, biases, output_name)
isempty(dates) && return nothing

output_path = joinpath(dir_paths.artifacts, "bias_$(output_name).png")
Leaderboard.plot_biases(atmos_sim.integrator.p.output_dir, compare_vars, dates; output_path)
Leaderboard.plot_biases(biases; output_path)
end
plot_biases(output_dates, "total")

ann_biases = compute_biases(output_dates)
plot_biases(output_dates, ann_biases, "total")

## collect all days between cs.dates.date0 and cs.dates.date
MAM, JJA, SON, DJF = Leaderboard.split_by_season(output_dates)

!isempty(MAM) && plot_biases(MAM, "MAM")
!isempty(JJA) && plot_biases(JJA, "JJA")
!isempty(SON) && plot_biases(SON, "SON")
!isempty(DJF) && plot_biases(DJF, "DJF")
MAM_biases = compute_biases(MAM)
plot_biases(MAM, MAM_biases, "MAM")
JJA_biases = compute_biases(JJA)
plot_biases(JJA, JJA_biases, "JJA")
SON_biases = compute_biases(SON)
plot_biases(SON, SON_biases, "SON")
DJF_biases = compute_biases(DJF)
plot_biases(DJF, DJF_biases, "DJF")

rmses = map(
(index) -> Leaderboard.RMSEs(;
model_name = "CliMA",
ANN = ann_biases[index],
DJF = DJF_biases[index],
JJA = JJA_biases[index],
MAM = MAM_biases[index],
SON = SON_biases[index],
),
1:length(compare_vars),
)

Leaderboard.plot_leaderboard(rmses; output_path = "bias_leaderboard.png")
end
end

Expand Down
56 changes: 52 additions & 4 deletions experiments/ClimaEarth/user_io/leaderboard/compare_with_obs.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
const OBS_DS = Dict()
const SIM_DS_KWARGS = Dict()
const OTHER_MODELS_RMSEs = Dict()

function preprocess_pr_fn(data)
# -1 kg/m/s2 -> 1 mm/day
return data .* Float32(-86400)
end

Base.@kwdef struct RMSEs
model_name::String
ANN::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0
DJF::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0
JJA::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0
MAM::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0
SON::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0
end

function Base.values(r::RMSEs)
val_or_rmse(v::Real) = v
val_or_rmse(v::ClimaAnalysis.OutputVar) = v.attributes["rmse"]

return val_or_rmse.([r.ANN, r.DJF, r.JJA, r.MAM, r.SON])
end

OBS_DS["pr"] =
ObsDataSource(; path = joinpath(pr_obs_data_path(), "gpcp.precip.mon.mean.197901-202305.nc"), var_name = "precip")

SIM_DS_KWARGS["pr"] = (; preprocess_data_fn = preprocess_pr_fn, new_units = "mm / day")

# TODO: These numbers are eyeballed and should not be really used. Use instead real values from the various models
OTHER_MODELS_RMSEs["pr"] = [RMSEs(; model_name = "AM4.0", ANN = 0.5, DJF = 1.0, JJA = 1.5, MAM = 0.5, SON = 1.0)]

Sbozzolo marked this conversation as resolved.
Show resolved Hide resolved
# OBS_DS["rsut"] = ObsDataSource(;
# path = "OBS/CERES_EBAF-TOA_Ed4.2_Subset_200003-202303.g025.nc",
# var_name = "toa_sw_all_mon",
Expand All @@ -27,13 +47,41 @@ function bias(output_dir::AbstractString, short_name::AbstractString, target_dat
return bias(obs, sim, target_dates)
end

function plot_biases(output_dir, short_names, target_dates::AbstractArray{<:Dates.DateTime}; output_path)
fig = CairoMakie.Figure(; size = (600, 300 * length(short_names)))
function compute_biases(output_dir, short_names, target_dates::AbstractArray{<:Dates.DateTime})
return map(name -> bias(output_dir, name, target_dates), short_names)
end

function plot_biases(biases; output_path)
fig = CairoMakie.Figure(; size = (600, 300 * length(biases)))
loc = 1
for short_name in short_names
bias_var = bias(output_dir, short_name, target_dates)
for bias_var in biases
ClimaAnalysis.Visualize.heatmap2D_on_globe!(fig, bias_var; p_loc = (1, loc))
loc = loc + 1
end
CairoMakie.save(output_path, fig)
end

function plot_leaderboard(rmses; output_path)
fig = CairoMakie.Figure(; size = (600, 300 * length(rmses)))
loc = 1

for rmse in rmses
short_name = rmse.ANN.attributes["var_short_name"]
units = rmse.ANN.attributes["units"]
ax = CairoMakie.Axis(
fig[1, loc],
ylabel = "$short_name [$units]",
xticks = (1:5, ["Ann", "DJF", "JJA", "MAM", "SON"]),
title = "Global RMSE",
)
CairoMakie.scatter!(ax, 1:5, values(rmse), label = rmse.model_name)
for other_model_rmse in OTHER_MODELS_RMSEs[short_name]
CairoMakie.scatter!(ax, 1:5, values(other_model_rmse), label = other_model_rmse.model_name)
end
# Add a fake extra point to center the legend a little better
CairoMakie.scatter!(ax, [6], [0.1], markersize = 0.01)
CairoMakie.axislegend()
loc = loc + 1
end
CairoMakie.save(output_path, fig)
end
1 change: 1 addition & 0 deletions experiments/ClimaEarth/user_io/leaderboard/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ function bias(obs_ds::ObsDataSource, sim_ds::SimDataSource, target_dates::Abstra

bias_attribs = Dict{String, Any}(
"short_name" => "sim-obs_$short_name",
"var_short_name" => "$short_name",
"long_name" => "SIM - OBS mean $short_name\n(RMSE: $rmse $units, Global bias: $global_bias $units)",
"rmse" => rmse,
"bias" => global_bias,
Expand Down
Loading