Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add destaggering functionality #93

Merged
merged 14 commits into from
Sep 16, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.7', '3.8', '3.9']
python-version: ['3.8', '3.9', '3.10']
steps:
- uses: actions/checkout@v2
- uses: conda-incubator/setup-miniconda@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/upstream-dev-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.9']
python-version: ['3.10']
steps:
- uses: actions/checkout@v2
- uses: conda-incubator/setup-miniconda@v2
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
'Intended Audience :: Science/Research',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Topic :: Scientific/Engineering',
]

Expand All @@ -29,7 +29,7 @@
description='A lightweight interface for working with the Weather Research and Forecasting (WRF) model output in Xarray.',
long_description=long_description,
long_description_content_type='text/markdown',
python_requires='>=3.7',
python_requires='>=3.8',
maintainer='xWRF Developers',
classifiers=CLASSIFIERS,
url='https://xwrf.readthedocs.io',
Expand Down
38 changes: 38 additions & 0 deletions tests/test_accessors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import pandas as pd
import pytest
import xarray as xr

import xwrf

from . import importorskip


@pytest.fixture(scope='session')
def test_grid(request):
    """Open the xwrf tutorial dataset whose name is supplied as the (indirect) fixture param."""
    dataset_name = request.param
    dataset = xwrf.tutorial.open_dataset(dataset_name)
    return dataset


@importorskip('cf_xarray')
@pytest.mark.parametrize(
'name, cf_grid_mapping_name', [('lambert_conformal', 'lambert_conformal_conic')]
Expand Down Expand Up @@ -59,3 +65,35 @@ def test_postprocess(name, cf_grid_mapping_name):
assert 'PB' not in ds.data_vars
assert 'PH' not in ds.data_vars
assert 'PHB' not in ds.data_vars


@pytest.mark.parametrize('test_grid', ['lambert_conformal', 'mercator'], indirect=True)
def test_dataarray_destagger(test_grid):
    """Check the DataArray accessor's ``destagger`` on the ``U`` variable."""
    staggered_u = test_grid['U']
    result = staggered_u.xwrf.destagger()

    # The staggered dim is renamed and ends up exactly one point shorter
    expected_size = staggered_u.sizes['west_east_stag'] - 1
    assert result.sizes['west_east'] == expected_size

    # The destaggered latitude coordinate agrees with the dataset's XLAT
    xr.testing.assert_allclose(result['XLAT'], test_grid['XLAT'])
jthielen marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize('test_grid', ['lambert_conformal', 'mercator'], indirect=True)
def test_dataset_destagger(test_grid):
    """Check the Dataset accessor's ``destagger`` after postprocessing two time steps."""
    subset = test_grid.isel(Time=slice(0, 2))
    postprocessed = subset.xwrf.postprocess(calculate_diagnostic_variables=False)
    destaggered = postprocessed.xwrf.destagger()

    # No data variable may keep a staggered dim or a non-empty "stagger" attr
    staggered_dim_names = {'x_stag', 'y_stag', 'z_stag'}
    for varname in destaggered.data_vars:
        variable = destaggered[varname]
        assert not staggered_dim_names.intersection(set(variable.dims))
        assert 'stagger' not in variable.attrs or variable.attrs['stagger'] == ''

    # Dataset-level attrs must be carried over unchanged
    assert destaggered.attrs == test_grid.attrs
80 changes: 80 additions & 0 deletions tests/test_destagger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from collections import Counter

import numpy as np
import pytest
import xarray as xr

import xwrf
from xwrf.destagger import _destag_variable, _drop_attrs, _rename_staggered_coordinate


@pytest.mark.parametrize(
    'input_attrs, output_attrs',
    [
        ({'a': 0, 'b': 1, 'c': 2}, {'a': 0, 'c': 2}),
        ({'b': 3}, {}),
        ({'a': 4, 'c': 5}, {'a': 4, 'c': 5}),
        (Counter(('a', 'a', 'b', 'c')), {'a': 2, 'c': 1}),
    ],
)
def test_drop_attrs_successful(input_attrs, output_attrs):
    """``_drop_attrs`` removes key ``'b'`` and always hands back a ``dict``."""
    cleaned = _drop_attrs(input_attrs, ('b',))
    assert isinstance(cleaned, dict)
    assert cleaned == output_attrs


def test_drop_attrs_unsuccessful():
    """An input without ``.items()`` (here a plain string) yields ``None``."""
    outcome = _drop_attrs('not a Mapping', ('a',))
    assert outcome is None


@pytest.mark.parametrize(
    'input_name, stagger_dim, unstag_dim_name, output_name',
    [
        ('bottom_top', 'bottom_top_stag', None, 'bottom_top'),
        ('bottom_top_stag', 'bottom_top_stag', 'z', 'z'),
        ('bottom_top_stag', 'bottom_top_stag', None, 'bottom_top'),
        ('XLAT_U', 'west_east', None, 'XLAT'),
        ('XLONG_V', 'south_north', None, 'XLONG'),
    ],
)
def test_rename_staggered_coordinate(input_name, stagger_dim, unstag_dim_name, output_name):
    """Each staggered coordinate name maps to the expected unstaggered name."""
    renamed = _rename_staggered_coordinate(input_name, stagger_dim, unstag_dim_name)
    assert renamed == output_name


def test_destag_variable_missing_dim():
    """Requesting a stagger dim that the variable does not have raises ValueError."""
    unstaggered = xr.Variable(('x', 'y'), np.zeros((2, 2)))
    with pytest.raises(ValueError):
        _destag_variable(unstaggered, 'z_stag')


def test_destag_variable_multiple_dims():
    """Guessing the stagger dim with two ``*_stag`` dims present raises NotImplementedError."""
    doubly_staggered = xr.Variable(('x_stag', 'y_stag'), np.zeros((2, 2)))
    with pytest.raises(NotImplementedError):
        _destag_variable(doubly_staggered)


@pytest.mark.parametrize(
    'unstag_dim_name, expected_output_dim_name',
    [('z', 'z'), (None, 'bottom_top')],
)
def test_destag_variable_1d(unstag_dim_name, expected_output_dim_name):
    """1-D destaggering averages neighbours, renames the dim and leaves no attrs."""
    staggered = xr.Variable(('bottom_top_stag',), np.arange(5), attrs={'stagger': 'Z'})
    output = _destag_variable(staggered, unstag_dim_name=unstag_dim_name)

    # Midpoints of 0..4 are 0.5, 1.5, 2.5, 3.5
    expected_values = 0.5 + np.arange(4)
    np.testing.assert_array_almost_equal(output.values, expected_values)

    # The single dim carries the requested (or derived) unstaggered name
    assert output.dims[0] == expected_output_dim_name

    # The only input attr was "stagger", so nothing should remain
    assert not output.attrs


def test_destag_variable_2d():
    """2-D destaggering averages along ``y_stag`` only and renames it to ``y``."""
    staggered = xr.Variable(('x', 'y_stag'), np.arange(9).reshape(3, 3))
    expected = xr.Variable(('x', 'y'), [[0.5, 1.5], [3.5, 4.5], [6.5, 7.5]])
    actual = _destag_variable(staggered)
    xr.testing.assert_equal(actual, expected)
90 changes: 90 additions & 0 deletions xwrf/accessors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations # noqa: F401

from typing import Dict, Optional
jthielen marked this conversation as resolved.
Show resolved Hide resolved

import xarray as xr

from .destagger import _destag_variable, _rename_staggered_coordinate
from .postprocess import (
_assign_coord_to_dim_of_different_name,
_calc_base_diagnostics,
Expand All @@ -27,6 +30,47 @@ def __init__(self, xarray_obj: xr.Dataset | xr.DataArray) -> WRFAccessor:
class WRFDataArrayAccessor(WRFAccessor):
"""Adds a number of WRF specific methods to xarray.DataArray objects."""

def destagger(
    self, stagger_dim: str | None = None, unstaggered_dim_name: str | None = None
) -> xr.DataArray:
    """
    Destagger a single WRF xarray.DataArray

    Parameters
    ----------
    stagger_dim : str, optional
        Name of dimension to unstagger. Defaults to guessing based on name (ends in "_stag")
    unstaggered_dim_name : str, optional
        String to which to rename the dimension after destaggering. Example would be
        "west_east" for "west_east_stag". By default the dimension will be renamed to the
        text in front of "_stag" from the "stagger_dim" field.

    Returns
    -------
    xarray.DataArray
        The destaggered DataArray with renamed dimension and adjusted coordinates. The
        name of the original DataArray is carried over.
    """
    new_variable = _destag_variable(
        self.xarray_obj.variable, stagger_dim=stagger_dim, unstag_dim_name=unstaggered_dim_name
    )
    new_dims = set(new_variable.dims)

    # Need to recalculate staggered coordinates, as they don't already exist independently
    # in a DataArray context
    new_coords = {}
    for coord_name, coord_data in self.xarray_obj.coords.items():
        if set(coord_data.dims).difference(new_dims):
            # Has a dimension not in the destaggered output (and so still staggered)
            new_name = _rename_staggered_coordinate(
                coord_name, stagger_dim=stagger_dim, unstag_dim_name=unstaggered_dim_name
            )
            # Pass the bare Variable, not the DataArray: DataArray arithmetic aligns the
            # two shifted slices on any index along the staggered dim, which would drop
            # the end points instead of averaging neighbouring cells.
            new_coords[new_name] = _destag_variable(
                coord_data.variable,
                stagger_dim=stagger_dim,
                unstag_dim_name=unstaggered_dim_name,
            )
        else:
            new_coords[coord_name] = coord_data.variable

    # Keep the original name so the result can still be merged back into a Dataset
    return xr.DataArray(new_variable, coords=new_coords, name=self.xarray_obj.name)


@xr.register_dataset_accessor('xwrf')
class WRFDatasetAccessor(WRFAccessor):
Expand Down Expand Up @@ -83,3 +127,49 @@ def postprocess(
ds = ds.pipe(_include_projection_coordinates)

return ds.pipe(_rename_dims)

def destagger(self, staggered_to_unstaggered_dims: dict[str, str] | None = None) -> xr.Dataset:
    """
    Destagger all data variables in a WRF xarray.Dataset

    Parameters
    ----------
    staggered_to_unstaggered_dims : dict, optional
        Mapping of target staggered dimensions to corresponding unstaggered dimensions.
        When omitted, every dataset dimension whose name ends in "_stag" is targeted and
        the unstaggered name is derived by ``_destag_variable``.

    Returns
    -------
    xarray.Dataset
        The destaggered dataset.

    Notes
    -----
    Does not destagger coordinates, and instead relies upon grid cell center coordinates
    already being present in the dataset.
    """
    dataset = self.xarray_obj

    if staggered_to_unstaggered_dims is None:
        staggered_dims = {dim for dim in dataset.dims if dim.endswith('_stag')}
    else:
        staggered_dims = set(staggered_to_unstaggered_dims)

    new_data_vars = {}
    for var_name, var_data in dataset.data_vars.items():
        matching_dims = set(var_data.dims).intersection(staggered_dims)
        if not matching_dims:
            # Nothing staggered here: pass the variable through untouched
            new_data_vars[var_name] = var_data.variable
            continue

        # NOTE: only one staggered dim is handled per variable. If several match, an
        # arbitrary one is taken from the set (open question: should this raise instead?).
        target_dim = matching_dims.pop()
        if staggered_to_unstaggered_dims is None:
            target_name = None
        else:
            target_name = staggered_to_unstaggered_dims[target_dim]

        new_data_vars[var_name] = _destag_variable(
            var_data.variable,
            stagger_dim=target_dim,
            unstag_dim_name=target_name,
        )

    return xr.Dataset(new_data_vars, dataset.coords, dataset.attrs)
90 changes: 90 additions & 0 deletions xwrf/destagger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import xarray as xr


def _drop_attrs(attrs_dict, attrs_to_drop):
try:
return {k: v for k, v in attrs_dict.items() if k not in attrs_to_drop}
except AttributeError:
return None


def _destag_variable(datavar, stagger_dim=None, unstag_dim_name=None):
"""
Destaggering function for a single wrf xarray.Variable

Based off of the wrf-python destagger function
https://github.com/NCAR/wrf-python/blob/22fb45c54f5193b849fdff0279445532c1a6c89f/src/wrf/destag.py
Copyright 2016 University Corporation for Atmospheric Research, reused with modification
according to the terms of the Apache License, Version 2.0

Parameters
----------
datavar : xarray.Variable
Data variable to be destaggered
stagger_dim : str, optional
Name of dimension to unstagger. Defaults to guessing based on name (ends in "_stag")
unstag_dim_name : str, option
String to which to rename the dimension after destaggering. Example would be
"west_east" for "west_east_stag". By default the dimenions will be renamed the text in
front of "_stag" from the "stagger_dim" field.

Returns
-------
xarray.Variable
The destaggered variable with renamed dimension
"""
# get the coordinate to unstagger
# option 1) user has provided the dimension
if stagger_dim and stagger_dim not in datavar.dims:
# check that the user-passed in stag dim is actually in there
raise ValueError(f'{stagger_dim} not in {datavar.dims}')

# option 2) guess the staggered dimension
elif stagger_dim is None:
# guess the name of the coordinate
stagger_dim = [x for x in datavar.dims if x.endswith('_stag')]

if len(stagger_dim) > 1:
raise NotImplementedError(
'Expected a single destagger dimensions. Found multiple destagger dimensions: '
f'{stagger_dim}'
)

# we need a string, not a list
stagger_dim = stagger_dim[0]

# get the size of the staggereed coordinate
stagger_dim_size = datavar.sizes[stagger_dim]

# I think the "dict(a="...")" format is preferrable... but you cant stick an fx arg string
# into that...
left_or_bottom_cells = datavar.isel({stagger_dim: slice(0, stagger_dim_size - 1)})
right_or_top_cells = datavar.isel({stagger_dim: slice(1, stagger_dim_size)})
center_mean = (left_or_bottom_cells + right_or_top_cells) * 0.5

# now change the variable name of the unstaggered coordinate
# we can pass this in if we want to, for whatever reason
if unstag_dim_name is None:
unstag_dim_name = stagger_dim.split('_stag')[
0
] # get the part of the name before the "_stag"

# return a data variable with renamed dimensions
return xr.Variable(
dims=tuple(str(unstag_dim_name) if dim == stagger_dim else dim for dim in center_mean.dims),
data=center_mean.data,
attrs=_drop_attrs(center_mean.attrs, ('stagger',)),
encoding=center_mean.encoding,
fastpath=True,
)


def _rename_staggered_coordinate(name, stagger_dim=None, unstag_dim_name=None):
if name == stagger_dim and unstag_dim_name is not None:
return unstag_dim_name
elif name[-2:].lower() in ('_u', '_v'):
return name[:-2]
elif name[-5:].lower() == '_stag':
return name[:-5]
else:
return name