Skip to content

Commit

Permalink
Regrid truncated data
Browse files Browse the repository at this point in the history
update import convention

final edits
  • Loading branch information
anastasia-popova committed Apr 13, 2024
1 parent 5e22e8b commit 0ced3c7
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 4 deletions.
22 changes: 22 additions & 0 deletions artifacts/artifact_funcs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,25 @@ function pr_obs_data_path()
)
return AW.get_data_folder(pr_obs_data)
end

"""
artifact_data(datapath_full, name)
Returns input dataset at datapath_full
"""
function artifact_data(datapath_full, name)
datafile_truncated = joinpath(datapath_full, string(name, ".nc"))
return datafile_truncated
end

"""
artifact_data(datapath_full, name, datapath_trunc, date0, time_start, time_end, comms_ctx)
Truncates given data set, and constructs a new dataset containing only the dates needed and stores it in datapath_trunc
"""
function artifact_data(datapath_full, name, datapath_trunc, date0, time_start, time_end, comms_ctx)
datafile = joinpath(datapath_full, string(name, ".nc"))
datafile_truncated =
Regridder.truncate_dataset(datafile, name, datapath_trunc, date0, time_start, time_end, comms_ctx)
return datafile_truncated
end
8 changes: 4 additions & 4 deletions experiments/AMIP/coupler_driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ original sources.
=#

include(joinpath(pkgdir(ClimaCoupler), "artifacts", "artifact_funcs.jl"))
sst_data = joinpath(sst_dataset_path(), "sst.nc")
sic_data = joinpath(sic_dataset_path(), "sic.nc")
co2_data = joinpath(co2_dataset_path(), "mauna_loa_co2.nc")
land_mask_data = joinpath(mask_dataset_path(), "seamask.nc")
sst_data = artifact_data(sst_dataset_path(), "sst", REGRID_DIR, date0, t_start, t_end, comms_ctx)
sic_data = artifact_data(sic_dataset_path(), "sic", REGRID_DIR, date0, t_start, t_end, comms_ctx)
co2_data = artifact_data(co2_dataset_path(), "mauna_loa_co2", REGRID_DIR, date0, t_start, t_end, comms_ctx)
land_mask_data = artifact_data(mask_dataset_path(), "seamask")

#=
## Component Model Initialization
Expand Down
80 changes: 80 additions & 0 deletions src/Regridder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -636,5 +636,85 @@ function cgll2latlonz(field; DIR = "cgll2latlonz_dir", nlat = 360, nlon = 720, c
return new_data, coords
end

"""
truncate_dataset(datafile, name, datapath_trunc, date0, time_start, time_end, comms_ctx)
Truncates given data set, and constructs a new dataset containing only the dates that are used in the simulation
"""
function truncate_dataset(
datafile,
name,
datapath_trunc,
date0,
time_start,
time_end,
comms_ctx::ClimaComms.AbstractCommsContext,
)
date_start = date0 + Dates.Second(time_start)
date_end = date0 + Dates.Second(time_start + time_end)

file_name = replace(string(name, "_truncated_data_", string(date_start), string(date_end), ".nc"), r":" => "")
datafile_truncated = joinpath(datapath_trunc, file_name)

if ClimaComms.iamroot(comms_ctx)
ds = NCDatasets.NCDataset(datafile, "r")
dates = ds["time"][:]

(start_id, end_id) = find_idx_bounding_dates(dates, date_start, date_end)

ds_truncated = NCDatasets.NCDataset(datafile_truncated, "c")
ds_truncated = NCDatasets.write(ds_truncated, NCDatasets.view(ds, time = start_id:end_id))

close(ds)
close(ds_truncated)

return datafile_truncated
end
end

"""
find_idx_bounding_dates(dates, date_start, date_end)
Returns the index range from dates that contains date_start to date_end
"""
function find_idx_bounding_dates(dates, date_start, date_end)
# if the simulation start date is before our first date in the dataset
# leave the beginning of the truncated dataset to be first date available
if date_start < dates[1]
start_id = 1
# if the simulation start date is after the last date in the dataset
# start the truncated dataset at its last possible date
elseif date_start > last(dates)
start_id = length(dates)
# if the simulation start date falls within the range of the dataset
# find the closest date to the start date and truncate there
else
(~, start_id) = findmin(x -> abs(x - date_start), dates)
# if the closest date is after the start date, add one more date before
if dates[start_id] > date_start
start_id = start_id - 1
end
end

# if the simulation end date is before our first date in the dataset
# truncate the end of the dataset to be the first date
if date_end < dates[1]
end_id = 1
# if the simulation end date is after the last date in the dataset
# leave the end of the dataset as is
elseif date_end > last(dates)
end_id = length(dates)
# if the simulation end date falls within the range of the dataset
# find the closest date to the end date and truncate there
else
(~, end_id) = findmin(x -> abs(x - date_end), dates)
# if the closest date is before the end date, add one more date after
if dates[end_id] < date_end
end_id = end_id + 1
end
end

return (; start_id, end_id)
end

end # Module
72 changes: 72 additions & 0 deletions test/regridder_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Dates
import NCDatasets
import ClimaComms
import ClimaCore as CC
import ClimaCoupler
import ClimaCoupler: Interfacer, Regridder, TestHelper, TimeManager

REGRID_DIR = @isdefined(REGRID_DIR) ? REGRID_DIR : joinpath("", "regrid_tmp/")
Expand Down Expand Up @@ -311,3 +312,74 @@ for FT in (Float32, Float64)
end
end
end
# test dataset truncation
@testset "test dataset truncation" begin
# Get the original dataset set up
include(joinpath(pkgdir(ClimaCoupler), "artifacts", "artifact_funcs.jl"))
sst_data_all = joinpath(sst_dataset_path(), "sst.nc")
ds = NCDatasets.NCDataset(sst_data_all, "r")
dates = ds["time"][:]
first_date = dates[1]
last_date = last(dates)

# set up comms_ctx
device = ClimaComms.device()
comms_ctx = ClimaComms.context(device)
ClimaComms.init(comms_ctx)

# make path for truncated datasets
COUPLER_OUTPUT_DIR = joinpath("experiments", "AMIP", "output", "tests")
mkpath(COUPLER_OUTPUT_DIR)

REGRID_DIR = joinpath(COUPLER_OUTPUT_DIR, "regrid_tmp", "")
mkpath(REGRID_DIR)

# values for the truncations
time_start = 0.0
time_end = 1.728e6
date0test = ["18690101", "18700101", "19790228", "20220301", "20230101"]
for date in date0test
date0 = Dates.DateTime(date, Dates.dateformat"yyyymmdd")
sst_data = Regridder.truncate_dataset(sst_data_all, "test", REGRID_DIR, date0, time_start, time_end, comms_ctx)
ds_truncated = NCDatasets.NCDataset(sst_data, "r")
new_dates = ds_truncated["time"][:]

date_start = date0 + Dates.Second(time_start)
date_end = date0 + Dates.Second(time_start + time_end)

# start date is before the first date of datafile
if date_start < first_date
@test new_dates[1] == first_date
# start date is after the last date in datafile
elseif date_start > last_date
@test new_dates[1] == last_date
# start date is within the bounds of the datafile
else
@test new_dates[1] <= date_start
@test new_dates[2] >= date_start
end

# end date is before the first date of datafile
if date_end < first_date
@test last(new_dates) == first_date
# end date is after the last date of datafile
elseif date_end > last_date
@test last(new_dates) == last_date
# end date is within the bounds of datafile
else
@test last(new_dates) >= date_end
@test new_dates[length(new_dates) - 1] <= date_end
end

# check that truncation is indexing correctly
all_data = ds["SST"][:, :, :]
new_data = ds_truncated["SST"][:, :, :]
(start_id, end_id) = Regridder.find_idx_bounding_dates(dates, date_start, date_end)
@test new_data[:, :, 1] all_data[:, :, start_id]
@test new_data[:, :, length(new_dates)] all_data[:, :, end_id]

close(ds_truncated)
end

close(ds)
end

0 comments on commit 0ced3c7

Please sign in to comment.