Skip to content

Commit

Permalink
Add new tests for Python frontend (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
sehnem authored Apr 9, 2024
1 parent f328bba commit b914c51
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pyrte_rrtmgp/kernels/rrtmgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
rrtmgp_compute_tau_rayleigh,
rrtmgp_interpolation,
)
from pyrte_rrtmgp.utils import convert_xarray_args


@convert_xarray_args
def interpolation(
neta: int,
flavor: npt.NDArray,
Expand Down Expand Up @@ -111,6 +113,7 @@ def interpolation(
return jtemp.T, fmajor.T, fminor.T, col_mix.T, tropo.T, jeta.T, jpress.T


@convert_xarray_args
def compute_planck_source(
tlay,
tlev,
Expand Down Expand Up @@ -208,6 +211,7 @@ def compute_planck_source(
return sfc_src.T, lay_src.T, lev_src.T, sfc_src_jac.T


@convert_xarray_args
def compute_tau_absorption(
idx_h2o,
gpoint_flavor,
Expand Down Expand Up @@ -337,6 +341,7 @@ def compute_tau_absorption(
return tau.T


@convert_xarray_args
def compute_tau_rayleigh(
gpoint_flavor,
band_lims_gpt,
Expand Down
4 changes: 4 additions & 0 deletions pyrte_rrtmgp/rrtmgp_gas_optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
@dataclass
class GasOptics:
tau: Optional[np.ndarray] = None
tau_rayleigh: Optional[np.ndarray] = None
tau_absorption: Optional[np.ndarray] = None
g: Optional[np.ndarray] = None
ssa: Optional[np.ndarray] = None
lay_src: Optional[np.ndarray] = None
Expand Down Expand Up @@ -310,6 +312,7 @@ def compute_gas_taus(self):
self._interpolated.jpress,
)

self.gas_optics.tau_absorption = tau_absorption
if self.source_is_internal:
self.gas_optics.tau = tau_absorption
self.gas_optics.ssa = np.full_like(tau_absorption, np.nan)
Expand All @@ -332,6 +335,7 @@ def compute_gas_taus(self):
self._interpolated.jtemp,
)

self.gas_optics.tau_rayleigh = tau_rayleigh
self.gas_optics.tau = tau_absorption + tau_rayleigh
self.gas_optics.ssa = np.where(
self.gas_optics.tau > 2.0 * np.finfo(float).tiny,
Expand Down
19 changes: 19 additions & 0 deletions pyrte_rrtmgp/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import xarray as xr


def get_usecols(solar_zenith_angle):
Expand Down Expand Up @@ -41,3 +42,21 @@ def compute_toa_flux(total_solar_irradiance, solar_source):
toa_flux = np.stack([solar_source] * ncol)
def_tsi = toa_flux.sum(axis=1)
return (toa_flux.T * (total_solar_irradiance / def_tsi)).T


def convert_xarray_args(func):
def wrapper(*args, **kwargs):
output_args = []
for x in args:
if isinstance(x, xr.DataArray):
output_args.append(x.data)
else:
output_args.append(x)
for k, v in kwargs.items():
if isinstance(v, xr.DataArray):
kwargs[k] = v.data
else:
kwargs[k] = v
return func(*output_args, **kwargs)

return wrapper
193 changes: 193 additions & 0 deletions tests/test_python_frontend/test_gas_optics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import os

import numpy as np
import pytest
import xarray as xr
from pyrte_rrtmgp import rrtmgp_gas_optics
from pyrte_rrtmgp.kernels.rrtmgp import (
compute_planck_source,
compute_tau_absorption,
compute_tau_rayleigh,
interpolation,
)

from utils import convert_args_arrays

ERROR_TOLERANCE = 1e-4

rte_rrtmgp_dir = os.environ.get("RRTMGP_DATA", "rrtmgp-data")
clear_sky_example_files = f"{rte_rrtmgp_dir}/examples/rfmip-clear-sky/inputs"

rfmip = xr.load_dataset(
f"{clear_sky_example_files}/multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc"
)
rfmip = rfmip.sel(expt=0) # only one experiment
kdist = xr.load_dataset(f"{rte_rrtmgp_dir}/rrtmgp-gas-lw-g256.nc")
kdist_sw = xr.load_dataset(f"{rte_rrtmgp_dir}/rrtmgp-gas-sw-g224.nc")

rrtmgp_gas_optics = kdist.gas_optics.load_atmosferic_conditions(rfmip)
rrtmgp_gas_optics_sw = kdist_sw.gas_optics.load_atmosferic_conditions(rfmip)

# Prepare the arguments for the interpolation function
interpolation_args = [
len(kdist["mixing_fraction"]),
kdist.gas_optics.flavors_sets,
kdist["press_ref"].values,
kdist["temp_ref"].values,
kdist["press_ref_trop"].values.item(),
kdist.gas_optics.vmr_ref,
rfmip["pres_layer"].values,
rfmip["temp_layer"].values,
kdist.gas_optics.col_gas,
]

expected_output = (
kdist.gas_optics._interpolated.jtemp,
kdist.gas_optics._interpolated.fmajor,
kdist.gas_optics._interpolated.fminor,
kdist.gas_optics._interpolated.col_mix,
kdist.gas_optics._interpolated.tropo,
kdist.gas_optics._interpolated.jeta,
kdist.gas_optics._interpolated.jpress,
)


@pytest.mark.parametrize(
"args, expected",
[(i, expected_output) for i in convert_args_arrays(interpolation_args)],
)
def test_compute_interpoaltion(args, expected):
result = interpolation(*args)
assert len(result) == len(expected)
for r, e in zip(result, expected):
assert r.shape == e.shape
assert np.isclose(r, e, atol=ERROR_TOLERANCE).all()


# Prepare the arguments for the compute_planck_source function
planck_source_args = [
rfmip["temp_layer"].data,
rfmip["temp_level"].data,
rfmip["surface_temperature"].data,
kdist.gas_optics.top_at_1,
kdist.gas_optics._interpolated.fmajor,
kdist.gas_optics._interpolated.jeta,
kdist.gas_optics._interpolated.tropo,
kdist.gas_optics._interpolated.jtemp,
kdist.gas_optics._interpolated.jpress,
kdist["bnd_limits_gpt"].data.T,
kdist["plank_fraction"].data.transpose(0, 2, 1, 3),
kdist["temp_ref"].data.min(),
kdist["temp_ref"].data.max(),
kdist["totplnk"].data.T,
kdist.gas_optics.gpoint_flavor,
]

expected_output = (
rrtmgp_gas_optics.sfc_src,
rrtmgp_gas_optics.lay_src,
rrtmgp_gas_optics.lev_src,
rrtmgp_gas_optics.sfc_src_jac,
)


@pytest.mark.parametrize(
"args, expected",
[(i, expected_output) for i in convert_args_arrays(planck_source_args)],
)
def test_compute_planck_source(args, expected):
result = compute_planck_source(*args)
assert len(result) == len(expected)
for r, e in zip(result, expected):
assert r.shape == e.shape
assert np.isclose(r, e, atol=ERROR_TOLERANCE).all()


# Prepare the arguments for the compute_tau_absorption function
minor_gases_lower = kdist.gas_optics.extract_names(kdist["minor_gases_lower"].data)
minor_gases_upper = kdist.gas_optics.extract_names(kdist["minor_gases_upper"].data)
idx_minor_lower = kdist.gas_optics.get_idx_minor(
kdist.gas_optics.gas_names, minor_gases_lower
)
idx_minor_upper = kdist.gas_optics.get_idx_minor(
kdist.gas_optics.gas_names, minor_gases_upper
)

scaling_gas_lower = kdist.gas_optics.extract_names(kdist["scaling_gas_lower"].data)
scaling_gas_upper = kdist.gas_optics.extract_names(kdist["scaling_gas_upper"].data)
idx_minor_scaling_lower = kdist.gas_optics.get_idx_minor(
kdist.gas_optics.gas_names, scaling_gas_lower
)
idx_minor_scaling_upper = kdist.gas_optics.get_idx_minor(
kdist.gas_optics.gas_names, scaling_gas_upper
)

tau_absorption_args = [
kdist.gas_optics.idx_h2o,
kdist.gas_optics.gpoint_flavor,
kdist["bnd_limits_gpt"].values.T,
kdist["kmajor"].values,
kdist["kminor_lower"].values,
kdist["kminor_upper"].values,
kdist["minor_limits_gpt_lower"].values.T,
kdist["minor_limits_gpt_upper"].values.T,
kdist["minor_scales_with_density_lower"].values.astype(bool),
kdist["minor_scales_with_density_upper"].values.astype(bool),
kdist["scale_by_complement_lower"].values.astype(bool),
kdist["scale_by_complement_upper"].values.astype(bool),
idx_minor_lower,
idx_minor_upper,
idx_minor_scaling_lower,
idx_minor_scaling_upper,
kdist["kminor_start_lower"].values,
kdist["kminor_start_upper"].values,
kdist.gas_optics._interpolated.tropo,
kdist.gas_optics._interpolated.col_mix,
kdist.gas_optics._interpolated.fmajor,
kdist.gas_optics._interpolated.fminor,
rfmip["pres_layer"].values,
rfmip["temp_layer"].values,
kdist.gas_optics.col_gas,
kdist.gas_optics._interpolated.jeta,
kdist.gas_optics._interpolated.jtemp,
kdist.gas_optics._interpolated.jpress,
]


@pytest.mark.parametrize(
"args, expected",
[
(i, rrtmgp_gas_optics.tau_absorption)
for i in convert_args_arrays(tau_absorption_args)
],
)
def test_compute_tau_absorption(args, expected):
result = compute_tau_absorption(*args)
assert np.isclose(result, expected, atol=ERROR_TOLERANCE).all()


# Prepare the arguments for the compute_tau_rayleigh function
tau_rayleigh_args = [
kdist_sw.gas_optics.gpoint_flavor,
kdist_sw["bnd_limits_gpt"].values.T,
np.stack([kdist_sw["rayl_lower"].values, kdist_sw["rayl_upper"].values], axis=-1),
kdist_sw.gas_optics.idx_h2o,
kdist_sw.gas_optics.col_gas[:, :, 0],
kdist_sw.gas_optics.col_gas,
kdist_sw.gas_optics._interpolated.fminor,
kdist_sw.gas_optics._interpolated.jeta,
kdist_sw.gas_optics._interpolated.tropo,
kdist_sw.gas_optics._interpolated.jtemp,
]


@pytest.mark.parametrize(
"args, expected",
[
(i, rrtmgp_gas_optics_sw.tau_rayleigh)
for i in convert_args_arrays(tau_rayleigh_args)
],
)
def test_compute_tau_rayleigh(args, expected):
result = compute_tau_rayleigh(*args)
assert np.isclose(result, expected, atol=ERROR_TOLERANCE).all()
24 changes: 24 additions & 0 deletions tests/test_python_frontend/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np
import xarray as xr


def convert_args_arrays(input_args, arrays_dtypes=[np.float64, np.float32]):
args_to_test = []
for dtype in arrays_dtypes:
args = []
for item in input_args:
if isinstance(item, np.ndarray) and item.dtype in arrays_dtypes:
output_item = item.astype(dtype)
else:
output_item = item
args.append(output_item)
args_to_test.append(args)
args = []
for item in input_args:
if isinstance(item, np.ndarray) and item.dtype in arrays_dtypes:
output_item = xr.DataArray(item)
else:
output_item = item
args.append(output_item)
args_to_test.append(args)
return args_to_test

0 comments on commit b914c51

Please sign in to comment.