Validate inputs #85

Draft · wants to merge 7 commits into base: main
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
### Added
- Function to validate input damages. ([PR #85](https://github.com/ClimateImpactLab/dscim/pull/85), [@JMGilbert](https://github.com/JMGilbert))

Reviewer comment (Member): Why add this in a new "unreleased" section here and not in the one below?

## [0.4.0] - Unreleased
### Added
- Functions to concatenate input damages across batches. ([PR #83](https://github.com/ClimateImpactLab/dscim/pull/83), [@davidrzhdu](https://github.com/davidrzhdu))
58 changes: 58 additions & 0 deletions src/dscim/preprocessing/input_damages.py
@@ -133,6 +133,7 @@ def concatenate_damage_output(damage_dir, basename, save_path):
data[v] = data[v].astype("unicode")

data.to_zarr(save_path, mode="w")
validate_damages("energy", save_path)


def calculate_labor_impacts(input_path, file_prefix, variable, val_type):
@@ -431,6 +432,8 @@ def process_batch(g):
store=save_path, mode="a", consolidated=True
)

validate_damages("agriculture", save_path)


def read_energy_files(df, seed="TINV_clim_price014_total_energy_fulladapt-histclim"):
"""Read energy CSV files and trasnform them to Xarray objects
@@ -818,6 +821,10 @@ def prep(
for v in data.values():
v.close()
damages.close()
validate_damages(
"mortality",
f"{outpath}/impacts-darwin-montecarlo-damages-v{mortality_version}.zarr",
)


def coastal_inputs(
@@ -853,6 +860,10 @@ def coastal_inputs(
consolidated=True,
mode="w",
)
validate_damages(
"coastal",
f"{path}/coastal_damages_{version}-{adapt_type}-{vsl_valuation}.zarr",
)
else:
print(
"vsl_valuation is not a dimension of the input dataset, subset adapt_type only"
@@ -863,3 +874,50 @@
consolidated=True,
mode="w",
)


def validate_damages(sector, path):
Reviewer comment (Member): Docstring? What are the input params and what is being checked? Expected behavior?

Be clear about what you're actually checking for. Don't be shy with inline comments either, because people will usually come back to read this kind of checking code in the future. (You did a pretty good job with comments here, so 👍)
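A docstring along these lines would answer those questions; this is only a sketch in my own wording, based on what the checks below actually do, not the PR author's text:

```python
def validate_damages(sector, path):
    """Validate the structure of an input damages zarr store.

    Parameters
    ----------
    sector : str
        Sector name. Any sector containing "coastal" is checked against the
        (ssp, model, slr, batch, year, region) layout; all other sectors
        against (ssp, rcp, model, gcm, batch, year, region).
    path : str
        Path to the zarr store to validate.

    Raises
    ------
    AssertionError
        If the batch coordinate is not exactly batch0-batch14, if (for
        non-coastal sectors) the rcp coordinate is not rcp45 and rcp85, if
        the dimension lengths differ from the expected layout, or if the
        batch dimension is not stored as a single chunk of 15.

    Warns
    -----
    UserWarning
        Non-fatal warning if any other chunk sizes differ from what is
        expected.
    """
```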

inputs = xr.open_zarr(path)
Reviewer comment (Member): Curious why there is IO here instead of taking in an already opened zarr store? Would that make it so you don't need to monkeypatch it for tests?

inputs.close()

# No repeated batch labels
batches_expected = np.sort(["batch" + str(i) for i in np.arange(0, 15)])
batches_actual = np.sort(inputs.batch.values)
assert np.array_equal(
batches_expected, batches_actual
), f"Batches in the {sector} input damages zarr are not 0-14."

# Input damages have rcp 4.5 and rcp 8.5
if "coastal" not in sector:
rcps_expected = np.sort(["rcp" + str(i) for i in [45, 85]])
rcps_actual = np.sort(inputs.rcp.values)
assert np.array_equal(
rcps_expected, rcps_actual
), f"RCPs in the {sector} input damages zarr are not rcp45 and rcp85."

# Expected dimension lengths and chunk layout for each sector
regions = inputs.dims["region"]
ssps = inputs.dims["ssp"]
if "coastal" in sector:
dims = ["ssp", "model", "slr", "batch", "year", "region"]
chunk_sizes = [1, 1, 1, 15, 10, regions]
total_sizes = [ssps, 2, 10, 15, 90, regions]
else:
dims = ["ssp", "rcp", "model", "gcm", "batch", "year", "region"]
chunk_sizes = [1, 1, 1, 1, 15, 10, regions]
total_sizes = [ssps, 2, 2, 33, 15, 90, regions]

chunk_len = np.arange(0, len(chunk_sizes))
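# Expected chunk layout per dimension: each dimension's chunk size repeated
# enough times to span its full length, e.g. (10,) * 9 for the 90-year
# dimension chunked by 10.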
chunks = [
(chunk_sizes[i],) * int(total_sizes[i] / chunk_sizes[i]) for i in chunk_len
]
dims_expected = dict(zip(dims, total_sizes))
chunks_expected = dict(zip(dims, chunks))

assert dims_expected == dict(inputs.dims)
for i in list(inputs.keys()):
Reviewer comment (Member): Might be able to just have this as inputs.keys() without casting it to a list?
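A minimal, self-contained illustration of the suggestion (the toy dataset here is made up for the example, not dscim data): xarray's Dataset implements the mapping interface, so its keys view is directly iterable.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(4)), "b": ("x", np.arange(4))})

# No list() cast needed: keys() is already iterable...
for name in ds.keys():
    print(name, ds[name].dims)

# ...and iterating .data_vars reads even more clearly when only the data
# variables are wanted.
for name in ds.data_vars:
    print(name)
```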

assert (
chunks_expected["batch"] == dict(inputs[i].chunksizes)["batch"]
), f"Chunksize for batches needs to equal 15 for the {sector} input damages."
if chunks_expected != dict(inputs[i].chunksizes):
warnings.warn("Non fatal: chunk sizes are different from expected.")
3 changes: 3 additions & 0 deletions src/dscim/preprocessing/preprocessing.py
@@ -11,6 +11,7 @@
import xarray as xr
from dask.distributed import Client, progress
from dscim.utils.functions import ce_func, mean_func
from dscim.preprocessing.input_damages import validate_damages
import yaml
import time
import argparse
@@ -96,6 +97,8 @@ def reduce_damages(
delta = params["delta"]
outpath = f"{c['paths']['reduced_damages_library']}/{sector}"

validate_damages(sector, damages)

with xr.open_zarr(damages, chunks=None)[histclim] as ds:
with xr.open_zarr(socioec, chunks=None) as gdppc:
assert (
190 changes: 189 additions & 1 deletion tests/test_input_damages.py
@@ -1,4 +1,5 @@
import os
import warnings
import numpy as np
import xarray as xr
import pandas as pd
@@ -23,7 +24,9 @@
calculate_energy_damages,
prep_mortality_damages,
coastal_inputs,
validate_damages,
)
import dscim

logger = logging.getLogger(__name__)

@@ -73,10 +76,14 @@ def test_parse_projection_filesys(tmp_path):
pd.testing.assert_frame_equal(df_out_expected, df_out_actual)


def test_concatenate_damage_output(tmp_path):
def test_concatenate_damage_output(tmp_path, monkeypatch):
"""
Test that concatenate_damage_output correctly concatenates damages across batches and saves to a single zarr file
"""
monkeypatch.setattr(
"dscim.preprocessing.input_damages.validate_damages", lambda *args: True
)

Reviewer comment (Member): (Also applies to all the monkeypatching below.)

Why does this need to be monkeypatched out?

Might it be better if the validation function took a dataset or dataarray as input instead of doing its own IO?

Patching like this in tests might be a symptom of a design problem.
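One shape the suggested refactor could take (a sketch only; `validate_damages_ds` is a hypothetical name, and the existing checks are assumed to move over unchanged): the validation logic operates on an already opened xarray.Dataset, and the path-based entry point becomes a thin IO wrapper.

```python
import xarray as xr


def validate_damages_ds(sector: str, inputs: xr.Dataset) -> None:
    """Run the structural checks on an already opened dataset (hypothetical)."""
    # ... the existing batch / rcp / dims / chunk-size checks would live here,
    # operating on `inputs` rather than re-opening the zarr store ...


def validate_damages(sector: str, path: str) -> None:
    """Thin IO wrapper: open the zarr store, then delegate to the pure check."""
    with xr.open_zarr(path) as inputs:
        validate_damages_ds(sector, inputs)
```

With that split, the tests below could build a small in-memory dataset and call `validate_damages_ds` on it directly, with no monkeypatching.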

d = os.path.join(tmp_path, "concatenate_in")
if not os.path.exists(d):
os.makedirs(d)
@@ -429,10 +436,16 @@ def test_calculate_labor_damages(
def test_compute_ag_damages(
tmp_path,
econvars_fixture,
monkeypatch,
):
"""
Test that compute_ag_damages correctly reshapes ag estimate runs for use in the integration system and saves them to a zarr file
"""

monkeypatch.setattr(
"dscim.preprocessing.input_damages.validate_damages", lambda *args: True
)

rcp = ["rcp45", "rcp85"]
gcm = ["ACCESS1-0", "GFDL-CM3"]
model = ["low", "high"]
@@ -1000,10 +1013,16 @@ def test_prep_mortality_damages(
tmp_path,
version_test,
econvars_fixture,
monkeypatch,
):
"""
Test that prep_mortality_damages correctly reshapes different versions of mortality estimate runs for use in the integration system and saves them to a zarr file
"""

monkeypatch.setattr(
"dscim.preprocessing.input_damages.validate_damages", lambda *args: True
)

for b in ["6", "9"]:
ds_in = xr.Dataset(
{
@@ -1151,10 +1170,16 @@ def test_error_prep_mortality_damages(tmp_path):
def test_coastal_inputs(
tmp_path,
version_test,
monkeypatch,
):
"""
Test that coastal_inputs correctly reshapes different versions of coastal results for use in the integration system and saves them to a zarr file (v0.21 and v0.22 have exactly the same structure, so testing either one should be sufficient)
"""

monkeypatch.setattr(
"dscim.preprocessing.input_damages.validate_damages", lambda *args: True
)

if version_test == "v0.21":
ds_in = xr.Dataset(
{
@@ -1356,3 +1381,166 @@ def test_error_coastal_inputs(
str(excinfo.value)
== "vsl_valuation is a coordinate in the input dataset but is set to None. Please provide a value for vsl_valuation by which to subset the input dataset."
)


def create_dummy_input_zarr(path, sector, file_type):
Reviewer comment (Member): Needs a docstring.
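A docstring like the following would cover it; a sketch in my wording, describing what the helper below actually builds:

```python
def create_dummy_input_zarr(path, sector, file_type):
    """Write a small dummy input damages zarr store for validate_damages tests.

    Parameters
    ----------
    path : str
        Location to write the zarr store.
    sector : str
        "coastal" builds the (ssp, model, slr, batch, year, region) layout;
        any other value builds the (ssp, rcp, model, gcm, batch, year,
        region) layout.
    file_type : str
        "correct" for a store that should pass validation, or one of
        "wrong_rcps", "wrong_batches", "wrong_chunk_sizes",
        "wrong_region_chunk_sizes" to inject the corresponding defect.
    """
```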

# Create dummy input data
batch_values = np.sort(["batch" + str(i) for i in np.arange(0, 15)])
rcp_values = np.sort(["rcp" + str(i) for i in [45, 85]])
ssp_values = np.arange(0, 2)
model_values = np.arange(0, 2)
if sector == "coastal":
slr_values = np.arange(0, 10)
else:
slr_values = np.arange(0, 33)
year_values = np.arange(0, 90)
region_values = np.arange(0, 6)

if file_type == "wrong_rcps":
# Create input data with wrong rcps
rcp_values = np.sort(["rcp" + str(i) for i in [45, 65, 85]])
elif file_type == "wrong_batches":
# Create input data with wrong batches
batch_values = np.sort(
["batch" + str(i) for i in np.arange(0, 14)]
+ [
"batch1",
]
)

if sector == "coastal":
data = np.ones(
(
len(ssp_values),
len(model_values),
len(slr_values),
len(batch_values),
len(year_values),
len(region_values),
)
)
else:
data = np.ones(
(
len(ssp_values),
len(rcp_values),
len(model_values),
len(slr_values),
len(batch_values),
len(year_values),
len(region_values),
)
)

# Create xarray dataset
if "coastal" in sector:
dims = ["ssp", "model", "slr", "batch", "year", "region"]
coords = {
"ssp": (["ssp"], ssp_values),
"model": (["model"], model_values),
"slr": (["slr"], slr_values),
"batch": (["batch"], batch_values),
"year": (["year"], year_values),
"region": (["region"], region_values),
}
chunkies = {
"ssp": 1,
"model": 1,
"slr": 1,
"batch": 5 if file_type == "wrong_chunk_sizes" else -1,
"year": 10,
"region": 3 if file_type == "wrong_region_chunk_sizes" else 6,
}
else:
dims = ["ssp", "rcp", "model", "gcm", "batch", "year", "region"]
coords = {
"ssp": (["ssp"], ssp_values),
"rcp": (["rcp"], rcp_values),
"model": (["model"], model_values),
"gcm": (["gcm"], slr_values),
"batch": (["batch"], batch_values),
"year": (["year"], year_values),
"region": (["region"], region_values),
}
chunkies = {
"ssp": 1,
"rcp": 1,
"model": 1,
"gcm": 1,
"batch": 5 if file_type == "wrong_chunk_sizes" else -1,
"year": 10,
"region": 3 if file_type == "wrong_region_chunk_sizes" else 6,
}

ds = xr.Dataset(
{
"data": (
dims,
data,
),
},
coords=coords,
).chunk(chunkies)

# Save xarray dataset as Zarr
ds.to_zarr(path, mode="w")


@pytest.mark.parametrize("sector", ["mortality", "coastal"])
def test_validate_damages_correct(tmp_path, sector):
Reviewer comment (Member): Needs a docstring. What behavior is this testing for?

path = str(tmp_path / f"damages_correct_{sector}.zarr")
file_type = "correct"
create_dummy_input_zarr(path, sector, file_type)
validate_damages(sector, path) # No assertion error should be raised


def test_validate_damages_incorrect_batches(tmp_path):
Reviewer comment (Member): Needs a docstring. What is this testing? I can kind of get a feel for it from the title, but you might want to say more specifically what you're looking for.

These docstrings also print when tests fail, so it's good to have them even if they seem obvious.
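For example, a docstring in the same style as the other tests in this file could read (my wording, describing what the test below does):

```python
def test_validate_damages_incorrect_batches(tmp_path):
    """
    Test that validate_damages raises an AssertionError when the batch
    coordinate of the input damages zarr is not exactly batch0-batch14
    (here batch14 is missing and batch1 is duplicated).
    """
```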

sector = "mortality"
path = str(tmp_path / f"damages_incorrect_batches_{sector}.zarr")
file_type = "wrong_batches"
create_dummy_input_zarr(path, sector, file_type)
with pytest.raises(AssertionError) as e_info:
validate_damages(sector, path)
assert (
str(e_info.value) == f"Batches in the {sector} input damages zarr are not 0-14."
)


def test_validate_damages_incorrect_rcps(tmp_path):
Reviewer comment (Member): Needs a docstring. It's harder to grok what this is testing for exactly.

sector = "mortality"
path = str(tmp_path / f"damages_incorrect_rcps_{sector}.zarr")
file_type = "wrong_rcps"
create_dummy_input_zarr(path, sector, file_type)
with pytest.raises(AssertionError) as e_info:
validate_damages(sector, path)
assert (
str(e_info.value)
== f"RCPs in the {sector} input damages zarr are not rcp45 and rcp85."
)


@pytest.mark.parametrize("sector", ["mortality", "coastal"])
def test_validate_damages_incorrect_chunk_sizes(tmp_path, sector):
Reviewer comment (Member): Needs a docstring.

path = str(tmp_path / f"damages_incorrect_chunk_sizes_{sector}.zarr")
file_type = "wrong_chunk_sizes"
create_dummy_input_zarr(path, sector, file_type)
with pytest.raises(AssertionError) as e_info:
validate_damages(sector, path)
assert (
str(e_info.value)
== f"Chunksize for batches needs to equal 15 for the {sector} input damages."
)


@pytest.mark.parametrize("sector", ["mortality", "coastal"])
def test_validate_damages_incorrect_region_chunk_sizes(tmp_path, sector):
Reviewer comment (Member): Needs a docstring.

path = str(tmp_path / f"damages_incorrect_region_chunk_sizes_{sector}.zarr")
file_type = "wrong_region_chunk_sizes"
create_dummy_input_zarr(path, sector, file_type)
with pytest.warns(UserWarning) as warnings_info:
validate_damages(sector, path)
assert len(warnings_info) == 1
assert (
str(warnings_info[0].message)
== "Non fatal: chunk sizes are different from expected."
)