Commit
Fix 236 (#247)
* Subset gridpoint: simplify for cases with lon, lat dims; add checks for irregular grids (see the usage sketch below the changed-files summary)

* Add checks to make sure input data has 'lon' and 'lat' attributes

* Add raise exception test

* Modify checks on lon, lat inputs; allow time-only selection

* Split repeated code into a subroutine: subset_time

* subset_bbox adjustments and tests

* Move longitude adjustment into a separate function

* Update HISTORY
tlogan2000 authored Jul 9, 2019
1 parent 4acf09f commit e20f6ea
Showing 3 changed files with 194 additions and 74 deletions.
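
Before the per-file diffs, here is a minimal usage sketch of the API this commit describes, assembled from the docstrings shown further down. The file name "tas.day.nc", the variable name, and the coordinate values are illustrative placeholders, not part of the commit:

import xarray as xr
from xclim import subset

ds = xr.open_dataset("tas.day.nc")  # placeholder file name

# Nearest grid point; .sel is used when 'lon' and 'lat' are 1-D dimensions,
# otherwise a geodesic nearest-neighbour search is performed.
pt = subset.subset_gridpoint(ds.tas, lon=-75.0, lat=45.0, start_yr=1990, end_yr=1999)

# Bounding box plus year range.
box = subset.subset_bbox(ds.tas, lon_bnds=[-75, -70], lat_bnds=[40, 45], start_yr=1990, end_yr=1999)

# Time-only selection, newly allowed (lon/lat may be omitted), and the new helper.
decade = subset.subset_gridpoint(ds.tas, start_yr=1990, end_yr=1999)
decade2 = subset.subset_time(ds.tas, start_yr=1990, end_yr=1999)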
1 change: 1 addition & 0 deletions HISTORY.rst
@@ -14,6 +14,7 @@ History
* Development build configurations are now available via both Anaconda and pip install methods.
* Modified create_ensembles() to allow creation of an ensemble dataset without a time dimension, as well as from xr.Datasets
* Modified create_ensembles() to pad input data with NaNs when time dimensions are unequal
* Updated subset_gridpoint() and subset_bbox() to use the .sel method if 'lon' and 'lat' dims are present.

0.10-beta (2019-06-06)
----------------------
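
The new HISTORY entry above refers to the fast path taken when 'lon' and 'lat' are one-dimensional dims. The commit's code passes boolean conditions (and method="nearest") to .sel; the toy xarray sketch below shows the same selection style using plain labels and slices, with made-up values that do not come from the commit:

import numpy as np
import xarray as xr

# Toy regular grid with 1-D 'lon' and 'lat' dimensions.
da = xr.DataArray(
    np.zeros((4, 5)),
    coords={"lat": [40.0, 42.0, 44.0, 46.0], "lon": [-76.0, -74.0, -72.0, -70.0, -68.0]},
    dims=("lat", "lon"),
)

# Gridpoint case: nearest neighbour along each 1-D coordinate.
point = da.sel(lat=46.1, lon=-72.4, method="nearest")

# Bounding-box case: label slices over monotonically increasing coordinates
# (the commit itself builds boolean lon/lat conditions instead of slices).
box = da.sel(lat=slice(41, 45), lon=slice(-75, -70))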
47 changes: 47 additions & 0 deletions tests/test_utils.py
@@ -586,14 +586,53 @@ def test_simple(self):
np.testing.assert_array_equal(out.time.dt.year.max(), yr_ed)
np.testing.assert_array_equal(out.time.dt.year.min(), yr_st)

# test time only
out = subset.subset_gridpoint(da, start_yr=yr_st, end_yr=yr_ed)
np.testing.assert_array_equal(len(np.unique(out.time.dt.year)), 10)
np.testing.assert_array_equal(out.time.dt.year.max(), yr_ed)
np.testing.assert_array_equal(out.time.dt.year.min(), yr_st)

def test_irregular(self):

da = xr.open_dataset(self.nc_2dlonlat).tasmax
lon = -72.4
lat = 46.1
out = subset.subset_gridpoint(da, lon=lon, lat=lat)
np.testing.assert_almost_equal(out.lon, lon, 1)
np.testing.assert_almost_equal(out.lat, lat, 1)

# test_irregular transposed:
da1 = xr.open_dataset(self.nc_2dlonlat).tasmax
dims = list(da1.dims)
dims.reverse()
daT = xr.DataArray(np.transpose(da1.values), dims=dims)
for d in daT.dims:
args = dict()
args[d] = da1[d]
daT = daT.assign_coords(**args)
daT = daT.assign_coords(lon=(["rlon", "rlat"], np.transpose(da1.lon.values)))
daT = daT.assign_coords(lat=(["rlon", "rlat"], np.transpose(da1.lat.values)))

out1 = subset.subset_gridpoint(daT, lon=lon, lat=lat)
np.testing.assert_almost_equal(out1.lon, lon, 1)
np.testing.assert_almost_equal(out1.lat, lat, 1)
np.testing.assert_array_equal(out, out1)

# Dataset with tasmax, lon and lat as data variables (i.e. lon, lat not coords of tasmax)
daT1 = xr.DataArray(np.transpose(da1.values), dims=dims)
for d in daT1.dims:
args = dict()
args[d] = da1[d]
daT1 = daT1.assign_coords(**args)
dsT = xr.Dataset(data_vars=None, coords=daT1.coords)
dsT["tasmax"] = daT1
dsT["lon"] = xr.DataArray(np.transpose(da1.lon.values), dims=["rlon", "rlat"])
dsT["lat"] = xr.DataArray(np.transpose(da1.lat.values), dims=["rlon", "rlat"])
out2 = subset.subset_gridpoint(dsT, lon=lon, lat=lat)
np.testing.assert_almost_equal(out2.lon, lon, 1)
np.testing.assert_almost_equal(out2.lat, lat, 1)
np.testing.assert_array_equal(out, out2.tasmax)

def test_positive_lons(self):
da = xr.open_dataset(self.nc_poslons).tas
lon = -72.4
@@ -611,6 +650,10 @@ def test_raise(self):
with pytest.raises(ValueError):
subset.subset_gridpoint(da, lon=-72.4, lat=46.1, start_yr=2056, end_yr=2055)

da = xr.open_dataset(self.nc_2dlonlat).tasmax.drop(["lon", "lat"])
with pytest.raises(Exception):
subset.subset_gridpoint(da, lon=-72.4, lat=46.1)


class TestSubsetBbox:
nc_poslons = os.path.join(
@@ -710,6 +753,10 @@ def test_raise(self):
da, lon_bnds=self.lon, lat_bnds=self.lat, start_yr=2056, end_yr=2055
)

da = xr.open_dataset(self.nc_2dlonlat).tasmax.drop(["lon", "lat"])
with pytest.raises(Exception):
subset.subset_bbox(da, lon_bnds=self.lon, lat_bnds=self.lat)


class TestThresholdCount:
def test_simple(self, tas_series):
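
The irregular-grid tests above exercise the geodesic nearest-neighbour fallback implemented in xclim/subset.py below. A self-contained sketch of that distance search on a toy 2-D lon/lat grid follows; pyproj is assumed available and all values are made up:

import numpy as np
from pyproj import Geod

# Toy curvilinear grid: 2-D lon/lat arrays indexed by (rlat, rlon).
j, i = np.meshgrid(np.arange(8), np.arange(10), indexing="ij")
lon2d = -75.0 + 0.5 * i + 0.05 * j   # mildly skewed, like a rotated-pole grid
lat2d = 44.0 + 0.5 * j + 0.05 * i

g = Geod(ellps="WGS84")
target_lon, target_lat = -72.4, 46.1

# Geodesic distance from every grid cell to the point of interest.
_, _, dist = g.inv(
    lon2d.ravel(),
    lat2d.ravel(),
    np.full(lon2d.size, target_lon),
    np.full(lat2d.size, target_lat),
)
iy, ix = np.unravel_index(np.argmin(dist), lon2d.shape)
print(lon2d[iy, ix], lat2d[iy, ix])  # nearest cell's coordinates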
220 changes: 146 additions & 74 deletions xclim/subset.py
@@ -40,42 +40,54 @@ def subset_bbox(da, lon_bnds=None, lat_bnds=None, start_yr=None, end_yr=None):
>>> ds = xr.open_mfdataset(['pr.day.nc','tas.day.nc'])
>>> dsSub = subset.subset_bbox(ds,lon_bnds=[-75,-70],lat_bnds=[40,45],start_yr=1990,end_yr=1999)
"""

if lon_bnds is not None:
lon_bnds = np.asarray(lon_bnds)
if np.all(da.lon > 0) and np.any(lon_bnds < 0):
lon_bnds[lon_bnds < 0] += 360
if np.all(da.lon < 0) and np.any(lon_bnds > 0):
lon_bnds[lon_bnds < 0] -= 360
da = da.where(
(da.lon >= lon_bnds.min()) & (da.lon <= lon_bnds.max()), drop=True
)

if lat_bnds is not None:
lat_bnds = np.asarray(lat_bnds)
da = da.where(
(da.lat >= lat_bnds.min()) & (da.lat <= lat_bnds.max()), drop=True
)
# check if trying to subset lon and lat

if not lat_bnds is None or not lon_bnds is None:
if hasattr(da, "lon") and hasattr(da, "lat"):
if lon_bnds is None:
lon_bnds = [da.lon.min(), da.lon.max()]

lon_bnds = _check_lons(da, np.asarray(lon_bnds))

lon_cond = (da.lon >= lon_bnds.min()) & (da.lon <= lon_bnds.max())

if lat_bnds is None:
lat_bnds = [da.lat.min(), da.lat.max()]

lat_bnds = np.asarray(lat_bnds)
lat_cond = (da.lat >= lat_bnds.min()) & (da.lat <= lat_bnds.max())
dims = list(da.dims)

if "lon" in dims and "lat" in dims:
da = da.sel(lon=lon_cond, lat=lat_cond)
else:
ind = np.where(lon_cond & lat_cond)
dims_lonlat = da.lon.dims
# reduce size using .sel with coordinate slices
args = {}
for d in dims_lonlat:
coords = da[d][ind[dims_lonlat.index(d)]]
args[d] = slice(coords.min(), coords.max())
da = da.sel(**args)
lon_cond = (da.lon >= lon_bnds.min()) & (da.lon <= lon_bnds.max())
lat_cond = (da.lat >= lat_bnds.min()) & (da.lat <= lat_bnds.max())

# mask irregular grid with new lat lon conditions
da = da.where(lon_cond & lat_cond, drop=True)
else:
raise (
Exception(
'subset_bbox() requires input data with "lon" and "lat" dimensions, coordinates or data variables.'
)
)

if start_yr or end_yr:
if not start_yr:
start_yr = da.time.dt.year.min()
if not end_yr:
end_yr = da.time.dt.year.max()

if start_yr > end_yr:
raise ValueError("Start date is after end date.")

year_bnds = np.asarray([start_yr, end_yr])
da = da.where(
(da.time.dt.year >= year_bnds.min()) & (da.time.dt.year <= year_bnds.max()),
drop=True,
)
da = subset_time(da, start_yr=start_yr, end_yr=end_yr)

return da


def subset_gridpoint(da, lon, lat, start_yr=None, end_yr=None):
def subset_gridpoint(da, lon=None, lat=None, start_yr=None, end_yr=None):
"""Extract a nearest gridpoint from datarray based on lat lon coordinate.
Time series can optionally be subsetted by year(s)
@@ -114,53 +126,113 @@ def subset_gridpoint(da, lon, lat, start_yr=None, end_yr=None):
>>> dsSub = subset.subset_gridpoint(ds, lon=-75,lat=45,start_yr=1990,end_yr=1999)
"""

g = Geod(ellps="WGS84") # WGS84 ellipsoid - decent globaly
# adjust negative/positive longitudes if necessary
if np.all(da.lon > 0) and lon < 0:
lon += 360
if np.all(da.lon < 0) and lon > 0:
lon -= 360

if len(da.lon.shape) == 1 & len(da.lat.shape) == 1:
# create a 2d grid of lon, lat values
lon1, lat1 = np.meshgrid(np.asarray(da.lon.values), np.asarray(da.lat.values))
# check if trying to subset lon and lat
if not lat is None and not lon is None:
# make sure input data has 'lon' and 'lat' (dims, coordinates, or data_vars)
if hasattr(da, "lon") and hasattr(da, "lat"):
# adjust negative/positive longitudes if necessary
lon = _check_lons(da, lon)

dims = list(da.dims)

# if 'lon' and 'lat' are present as data dimensions, use the .sel method.
if "lat" in dims and "lon" in dims:
da = da.sel(lat=lat, lon=lon, method="nearest")
else:
g = Geod(ellps="WGS84") # WGS84 ellipsoid - decent globaly
lon1 = da.lon.values
lat1 = da.lat.values
shp_orig = lon1.shape
lon1 = np.reshape(lon1, lon1.size)
lat1 = np.reshape(lat1, lat1.size)
# calculate geodesic distance between grid points and point of interest
az12, az21, dist = g.inv(
lon1,
lat1,
np.broadcast_to(lon, lon1.shape),
np.broadcast_to(lat, lat1.shape),
)
dist = dist.reshape(shp_orig)
iy, ix = np.unravel_index(np.argmin(dist, axis=None), dist.shape)
xydims = [x for x in da.lon.dims]
args = dict()
args[xydims[0]] = iy
args[xydims[1]] = ix
da = da.isel(**args)
else:
raise (
Exception(
'subset_gridpoint() requires input data with "lon" and "lat" coordinates or data variables.'
)
)

else:
lon1 = da.lon.values
lat1 = da.lat.values
shp_orig = lon1.shape
lon1 = np.reshape(lon1, lon1.size)
lat1 = np.reshape(lat1, lat1.size)
# calculate geodesic distance between grid points and point of interest
az12, az21, dist = g.inv(
lon1, lat1, np.broadcast_to(lon, lon1.shape), np.broadcast_to(lat, lat1.shape)
)
dist = dist.reshape(shp_orig)

iy, ix = np.unravel_index(np.argmin(dist, axis=None), dist.shape)
xydims = [x for x in da.dims if "time" not in x]

args = dict()
args[xydims[0]] = iy
args[xydims[1]] = ix
out = da.isel(**args)
if start_yr or end_yr:
if not start_yr:
start_yr = da.time.dt.year.min()
if not end_yr:
end_yr = da.time.dt.year.max()
if start_yr > end_yr:
raise ValueError("Start date is after end date.")

year_bnds = np.asarray([start_yr, end_yr])

if len(year_bnds) == 1:
time_cond = da.time.dt.year == year_bnds
else:
time_cond = (da.time.dt.year >= year_bnds.min()) & (
da.time.dt.year <= year_bnds.max()
)
out = out.where(time_cond, drop=True)

return out

da = subset_time(da, start_yr=start_yr, end_yr=end_yr)

return da


def subset_time(da, start_yr=None, end_yr=None):
"""Subset input data based on start and end years
Return a subsetted data array (or dataset) for years falling
within provided year bounds
Parameters
----------
da : xarray.DataArray or xarray.DataSet
Input data.
start_yr : int
First year of the subset. Defaults to first year of input.
end_yr : int
Last year of the subset. Defaults to last year of input.
Returns
-------
xarray.DataArray or xarray.DataSet
Subsetted data array or dataset
Examples
--------
>>> from xclim import subset
>>> ds = xr.open_dataset('pr.day.nc')
Subset multiple years
>>> prSub = subset.subset_time(ds.pr,start_yr=1990,end_yr=1999)
Subset single year
>>> prSub = subset.subset_time(ds.pr,start_yr=1990,end_yr=1990)
Subset multiple variables in a single dataset
>>> ds = xr.open_mfdataset(['pr.day.nc','tas.day.nc'])
>>> dsSub = subset.subset_time(ds,start_yr=1990,end_yr=1999)
"""

if not start_yr:
start_yr = da.time.dt.year.min()
if not end_yr:
end_yr = da.time.dt.year.max()

if start_yr > end_yr:
raise ValueError("Start date is after end date.")

year_bnds = np.asarray([start_yr, end_yr])

if len(year_bnds) == 1:
time_cond = da.time.dt.year == year_bnds
else:
time_cond = (da.time.dt.year >= year_bnds.min()) & (
da.time.dt.year <= year_bnds.max()
)
return da.sel(time=time_cond)


def _check_lons(da, lon_bnds):
if np.all(da.lon > 0) and np.any(lon_bnds < 0):
if isinstance(lon_bnds, float):
lon_bnds += 360
else:
lon_bnds[lon_bnds < 0] += 360
if np.all(da.lon < 0) and np.any(lon_bnds > 0):
if isinstance(lon_bnds, float):
lon_bnds -= 360
else:
lon_bnds[lon_bnds > 0] -= 360
return lon_bnds
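
For reference, a tiny worked illustration of the wrap-around adjustment _check_lons performs when the grid uses a 0-360 longitude axis and the request is given in the -180..180 convention (toy values, not from the commit):

import numpy as np

grid_lon = np.arange(2.5, 360.0, 2.5)   # all-positive longitude axis
lon_bnds = np.asarray([-75.0, -70.0])   # request in -180..180 convention

# Same shift as _check_lons: move negative requests onto the 0-360 axis.
if np.all(grid_lon > 0) and np.any(lon_bnds < 0):
    lon_bnds[lon_bnds < 0] += 360
print(lon_bnds)  # [285. 290.]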
