From 0ced3c795abf60179428d4a8cc0da1902feac7ac Mon Sep 17 00:00:00 2001
From: Anastasia Popova
Date: Mon, 1 Apr 2024 15:00:37 -0400
Subject: [PATCH] Regrid truncated data update import convention final edits

---
 artifacts/artifact_funcs.jl        | 21 +++++++
 experiments/AMIP/coupler_driver.jl |  8 +--
 src/Regridder.jl                   | 90 ++++++++++++++++++++++++++++++
 test/regridder_tests.jl            | 72 ++++++++++++++++++++++++
 4 files changed, 187 insertions(+), 4 deletions(-)

diff --git a/artifacts/artifact_funcs.jl b/artifacts/artifact_funcs.jl
index 93eac7cf82..4e5341f68d 100644
--- a/artifacts/artifact_funcs.jl
+++ b/artifacts/artifact_funcs.jl
@@ -59,3 +59,24 @@ function pr_obs_data_path()
     )
     return AW.get_data_folder(pr_obs_data)
 end
+
+"""
+    artifact_data(datapath_full, name)
+
+Return the path of the full (untruncated) dataset `name.nc` inside `datapath_full`.
+"""
+function artifact_data(datapath_full, name)
+    return joinpath(datapath_full, string(name, ".nc"))
+end
+
+"""
+    artifact_data(datapath_full, name, datapath_trunc, date0, time_start, time_end, comms_ctx)
+
+Truncate the dataset `name.nc` found in `datapath_full` to the dates needed by
+the simulation, store the result in `datapath_trunc`, and return the path of
+the truncated file (see `Regridder.truncate_dataset`).
+"""
+function artifact_data(datapath_full, name, datapath_trunc, date0, time_start, time_end, comms_ctx)
+    datafile = joinpath(datapath_full, string(name, ".nc"))
+    return Regridder.truncate_dataset(datafile, name, datapath_trunc, date0, time_start, time_end, comms_ctx)
+end
diff --git a/experiments/AMIP/coupler_driver.jl b/experiments/AMIP/coupler_driver.jl
index 830a1f7b86..0e555629d2 100644
--- a/experiments/AMIP/coupler_driver.jl
+++ b/experiments/AMIP/coupler_driver.jl
@@ -168,10 +168,10 @@ original sources.
 =#
 
 include(joinpath(pkgdir(ClimaCoupler), "artifacts", "artifact_funcs.jl"))
-sst_data = joinpath(sst_dataset_path(), "sst.nc")
-sic_data = joinpath(sic_dataset_path(), "sic.nc")
-co2_data = joinpath(co2_dataset_path(), "mauna_loa_co2.nc")
-land_mask_data = joinpath(mask_dataset_path(), "seamask.nc")
+sst_data = artifact_data(sst_dataset_path(), "sst", REGRID_DIR, date0, t_start, t_end, comms_ctx)
+sic_data = artifact_data(sic_dataset_path(), "sic", REGRID_DIR, date0, t_start, t_end, comms_ctx)
+co2_data = artifact_data(co2_dataset_path(), "mauna_loa_co2", REGRID_DIR, date0, t_start, t_end, comms_ctx)
+land_mask_data = artifact_data(mask_dataset_path(), "seamask")
 
 #=
 ## Component Model Initialization
diff --git a/src/Regridder.jl b/src/Regridder.jl
index 9aab3ee53d..4a09a63c0d 100644
--- a/src/Regridder.jl
+++ b/src/Regridder.jl
@@ -636,5 +636,95 @@ function cgll2latlonz(field; DIR = "cgll2latlonz_dir", nlat = 360, nlon = 720, c
 
     return new_data, coords
 end
+"""
+    truncate_dataset(datafile, name, datapath_trunc, date0, time_start, time_end, comms_ctx)
+
+Write a copy of `datafile` restricted to the time indices that bracket the
+simulation period, and return the path of the truncated file.
+
+The period runs from `date0 + time_start` to `date0 + time_start + time_end`
+(both offsets in seconds). Only the root process of `comms_ctx` writes the file;
+every process waits for it and receives the same path.
+"""
+function truncate_dataset(
+    datafile,
+    name,
+    datapath_trunc,
+    date0,
+    time_start,
+    time_end,
+    comms_ctx::ClimaComms.AbstractCommsContext,
+)
+    date_start = date0 + Dates.Second(time_start)
+    date_end = date0 + Dates.Second(time_start + time_end)
+
+    # the colons of the two timestamps are stripped from the file name
+    file_name = replace(string(name, "_truncated_data_", string(date_start), string(date_end), ".nc"), r":" => "")
+    datafile_truncated = joinpath(datapath_trunc, file_name)
+
+    if ClimaComms.iamroot(comms_ctx)
+        # the `do` blocks close both datasets even if reading or writing throws
+        NCDatasets.NCDataset(datafile, "r") do ds
+            dates = ds["time"][:]
+            (start_id, end_id) = find_idx_bounding_dates(dates, date_start, date_end)
+
+            NCDatasets.NCDataset(datafile_truncated, "c") do ds_truncated
+                NCDatasets.write(ds_truncated, NCDatasets.view(ds, time = start_id:end_id))
+            end
+        end
+    end
+
+    # non-root processes must not use the file before the root has finished writing it
+    ClimaComms.barrier(comms_ctx)
+
+    # returned outside the root-only branch so that no process receives `nothing`
+    return datafile_truncated
+end
+
+"""
+    find_idx_bounding_dates(dates, date_start, date_end)
+
+Return `(; start_id, end_id)`, the first and last indices of the range of
+`dates` that contains `date_start` to `date_end`. Bounds that fall outside
+`dates` are clamped to its first or last index.
+
+`dates` is expected to be sorted in increasing order; an empty `dates` throws
+an `ArgumentError`.
+"""
+function find_idx_bounding_dates(dates, date_start, date_end)
+    isempty(dates) && throw(ArgumentError("`dates` must contain at least one date"))
+
+    if date_start < dates[1]
+        # the simulation starts before the dataset: keep the dataset's first date
+        start_id = 1
+    elseif date_start > last(dates)
+        # the simulation starts after the dataset: start at its last date
+        start_id = length(dates)
+    else
+        # the start falls within the dataset: take the closest date ...
+        (_, start_id) = findmin(x -> abs(x - date_start), dates)
+        # ... and step back by one if that date is after the start
+        if dates[start_id] > date_start
+            start_id = start_id - 1
+        end
+    end
+
+    if date_end < dates[1]
+        # the simulation ends before the dataset: end at the dataset's first date
+        end_id = 1
+    elseif date_end > last(dates)
+        # the simulation ends after the dataset: keep the dataset's last date
+        end_id = length(dates)
+    else
+        # the end falls within the dataset: take the closest date ...
+        (_, end_id) = findmin(x -> abs(x - date_end), dates)
+        # ... and step forward by one if that date is before the end
+        if dates[end_id] < date_end
+            end_id = end_id + 1
+        end
+    end
+
+    return (; start_id, end_id)
+end
 
 end # Module
diff --git a/test/regridder_tests.jl b/test/regridder_tests.jl
index cb0029aad5..0001e7a632 100644
--- a/test/regridder_tests.jl
+++ b/test/regridder_tests.jl
@@ -6,6 +6,7 @@ import Dates
 import NCDatasets
 import ClimaComms
 import ClimaCore as CC
+import ClimaCoupler
 import ClimaCoupler: Interfacer, Regridder, TestHelper, TimeManager
 
 REGRID_DIR = @isdefined(REGRID_DIR) ? REGRID_DIR : joinpath("", "regrid_tmp/")
@@ -311,3 +312,74 @@ for FT in (Float32, Float64)
         end
     end
 end
+# test dataset truncation
+@testset "test dataset truncation" begin
+    # Get the original dataset set up
+    include(joinpath(pkgdir(ClimaCoupler), "artifacts", "artifact_funcs.jl"))
+    sst_data_all = joinpath(sst_dataset_path(), "sst.nc")
+    ds = NCDatasets.NCDataset(sst_data_all, "r")
+    dates = ds["time"][:]
+    all_data = ds["SST"][:, :, :] # read once: it is the same for every tested date
+    first_date = dates[1]
+    last_date = last(dates)
+
+    # set up comms_ctx
+    device = ClimaComms.device()
+    comms_ctx = ClimaComms.context(device)
+    ClimaComms.init(comms_ctx)
+
+    # make path for truncated datasets
+    COUPLER_OUTPUT_DIR = joinpath("experiments", "AMIP", "output", "tests")
+    mkpath(COUPLER_OUTPUT_DIR)
+
+    REGRID_DIR = joinpath(COUPLER_OUTPUT_DIR, "regrid_tmp", "")
+    mkpath(REGRID_DIR)
+
+    # values for the truncations
+    time_start = 0.0
+    time_end = 1.728e6
+    date0test = ["18690101", "18700101", "19790228", "20220301", "20230101"]
+    for date in date0test
+        date0 = Dates.DateTime(date, Dates.dateformat"yyyymmdd")
+        sst_data = Regridder.truncate_dataset(sst_data_all, "test", REGRID_DIR, date0, time_start, time_end, comms_ctx)
+        ds_truncated = NCDatasets.NCDataset(sst_data, "r")
+        new_dates = ds_truncated["time"][:]
+
+        date_start = date0 + Dates.Second(time_start)
+        date_end = date0 + Dates.Second(time_start + time_end)
+
+        # start date is before the first date of datafile
+        if date_start < first_date
+            @test new_dates[1] == first_date
+            # start date is after the last date in datafile
+        elseif date_start > last_date
+            @test new_dates[1] == last_date
+            # start date is within the bounds of the datafile
+        else
+            @test new_dates[1] <= date_start
+            @test new_dates[2] >= date_start
+        end
+
+        # end date is before the first date of datafile
+        if date_end < first_date
+            @test last(new_dates) == first_date
+            # end date is after the last date of datafile
+        elseif date_end > last_date
+            @test last(new_dates) == last_date
+            # end date is within the bounds of datafile
+        else
+            @test last(new_dates) >= date_end
+            @test new_dates[end - 1] <= date_end
+        end
+
+        # check that truncation is indexing correctly
+        new_data = ds_truncated["SST"][:, :, :]
+        (start_id, end_id) = Regridder.find_idx_bounding_dates(dates, date_start, date_end)
+        @test new_data[:, :, 1] ≈ all_data[:, :, start_id]
+        @test new_data[:, :, length(new_dates)] ≈ all_data[:, :, end_id]
+
+        close(ds_truncated)
+    end
+
+    close(ds)
+end