Skip to content

Commit

Permalink
Merge pull request #346 from djhoese/feature-dynamic-area-def-dask
Browse files Browse the repository at this point in the history
Add better dask handling to DynamicAreaDefinitions
  • Loading branch information
djhoese authored Apr 14, 2021
2 parents 15fab19 + 81bc10b commit f71a474
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 46 deletions.
36 changes: 35 additions & 1 deletion pyresample/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,9 +980,14 @@ def freeze(self, lonslats=None, resolution=None, shape=None, proj_info=None):
area_extent, self.rotation)

def _compute_bound_centers(self, proj_dict, lonslats):
proj4 = Proj(proj_dict)
lons, lats = self._extract_lons_lats(lonslats)
if hasattr(lons, 'compute'):
return self._compute_bound_centers_dask(proj_dict, lons, lats)
return self._compute_bound_centers_numpy(proj_dict, lons, lats)

def _compute_bound_centers_numpy(self, proj_dict, lons, lats):
# TODO: Do more dask-friendly things here
proj4 = Proj(proj_dict)
xarr, yarr = proj4(np.asarray(lons), np.asarray(lats))
xarr[xarr > 9e29] = np.nan
yarr[yarr > 9e29] = np.nan
Expand All @@ -999,6 +1004,35 @@ def _compute_bound_centers(self, proj_dict, lonslats):
xmax = np.nanmax(xarr[xarr < 0]) + 360
return xmin, ymin, xmax, ymax

def _compute_bound_centers_dask(self, proj_dict, lons, lats):
    """Compute the projected bounds of dask-backed lon/lat centers.

    The coordinates are transformed from EPSG:4326 to the CRS described by
    ``proj_dict`` and the four extrema are evaluated in a single
    ``da.compute`` call. For geographic CRSs whose x range spans more than
    355 while staying at least 0.1 away from +/-90 in y, the x bounds are
    recomputed in a second ``da.compute`` call so that the maximum is
    expressed above 180.

    Args:
        proj_dict: Projection definition passed to ``CRS``.
        lons: Longitudes, passed to ``DaskFriendlyTransformer.transform``.
        lats: Latitudes, passed to ``DaskFriendlyTransformer.transform``.

    Returns:
        Tuple ``(xmin, ymin, xmax, ymax)`` of computed values.

    """
    import dask.array as da
    from pyresample.utils.proj4 import DaskFriendlyTransformer

    target_crs = CRS(proj_dict)
    lonlat_to_target = DaskFriendlyTransformer.from_crs(
        CRS(4326), target_crs, always_xy=True)
    xarr, yarr = lonlat_to_target.transform(lons, lats)

    # Anything above 9e29 is replaced by NaN so the nan-aware reductions
    # below skip it (presumably the projection's invalid-value marker).
    xarr = da.where(xarr > 9e29, np.nan, xarr)
    yarr = da.where(yarr > 9e29, np.nan, yarr)

    # Build all four reductions lazily, then evaluate them together.
    lazy_extrema = (
        np.nanmin(xarr),
        np.nanmax(xarr),
        np.nanmin(yarr),
        np.nanmax(yarr),
    )
    xmin, xmax, ymin, ymax = da.compute(*lazy_extrema)

    pole_margin = 0.1
    touches_pole = (ymax >= 90 - pole_margin) or (ymin <= -90 + pole_margin)
    spans_antimeridian = (xmax - xmin) > 355
    if target_crs.is_geographic and spans_antimeridian and not touches_pole:
        # cross anti-meridian of projection: take the minimum of the
        # non-negative x values and shift the maximum of the negative
        # ones by 360
        east_min = np.nanmin(xarr[xarr >= 0])
        west_max = np.nanmax(xarr[xarr < 0]) + 360
        xmin, xmax = da.compute(east_min, west_max)
    return xmin, ymin, xmax, ymax

def _extract_lons_lats(self, lonslats):
try:
lons, lats = lonslats
Expand Down
90 changes: 46 additions & 44 deletions pyresample/test/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
import unittest
import pyproj

try:
from pyproj import CRS
except ImportError:
CRS = None
from pyproj import CRS

import dask.array as da
import xarray as xr


class Test(unittest.TestCase):
Expand Down Expand Up @@ -417,44 +417,33 @@ def test_swath_hash(self):

self.assertIsInstance(hash(swath_def), int)

try:
import dask.array as da
except ImportError:
print("Not testing with dask arrays")
else:
dalons = da.from_array(lons, chunks=1000)
dalats = da.from_array(lats, chunks=1000)
swath_def = geometry.SwathDefinition(dalons, dalats)

self.assertIsInstance(hash(swath_def), int)

try:
import xarray as xr
except ImportError:
print("Not testing with xarray")
else:
xrlons = xr.DataArray(lons)
xrlats = xr.DataArray(lats)
swath_def = geometry.SwathDefinition(xrlons, xrlats)

self.assertIsInstance(hash(swath_def), int)

try:
import xarray as xr
import dask.array as da
except ImportError:
print("Not testing with xarrays and dask arrays")
else:
xrlons = xr.DataArray(da.from_array(lons, chunks=1000))
xrlats = xr.DataArray(da.from_array(lats, chunks=1000))
swath_def = geometry.SwathDefinition(xrlons, xrlats)

self.assertIsInstance(hash(swath_def), int)
def test_swath_hash_dask(self):
    """Check that a SwathDefinition built from dask arrays hashes to an int."""
    lazy_lons = da.from_array(np.array([1.2, 1.3, 1.4, 1.5]), chunks=1000)
    lazy_lats = da.from_array(np.array([65.9, 65.86, 65.82, 65.78]),
                              chunks=1000)
    swath = geometry.SwathDefinition(lazy_lons, lazy_lats)
    swath_hash = hash(swath)
    self.assertIsInstance(swath_hash, int)

lons = np.ma.array([1.2, 1.3, 1.4, 1.5])
lats = np.ma.array([65.9, 65.86, 65.82, 65.78])
swath_def = geometry.SwathDefinition(lons, lats)
def test_swath_hash_xarray(self):
    """Check that a SwathDefinition built from DataArrays hashes to an int."""
    lon_values = np.array([1.2, 1.3, 1.4, 1.5])
    lat_values = np.array([65.9, 65.86, 65.82, 65.78])
    swath = geometry.SwathDefinition(xr.DataArray(lon_values),
                                     xr.DataArray(lat_values))
    swath_hash = hash(swath)
    self.assertIsInstance(swath_hash, int)

def test_swath_hash_xarray_with_dask(self):
    """Check hashing of a SwathDefinition built from dask-backed DataArrays."""
    lon_values = np.array([1.2, 1.3, 1.4, 1.5])
    lat_values = np.array([65.9, 65.86, 65.82, 65.78])
    # Wrap the numpy data in dask first, then in xarray.
    lazy_lons = xr.DataArray(da.from_array(lon_values, chunks=1000))
    lazy_lats = xr.DataArray(da.from_array(lat_values, chunks=1000))
    swath = geometry.SwathDefinition(lazy_lons, lazy_lats)
    swath_hash = hash(swath)
    self.assertIsInstance(swath_hash, int)

def test_area_equal(self):
Expand Down Expand Up @@ -2290,16 +2279,29 @@ def test_freeze(self):
(np.linspace(-75, -90.0, 10),),
],
)
def test_freeze_longlat_antimeridian(self, lats):
@pytest.mark.parametrize('use_dask', [False, True])
def test_freeze_longlat_antimeridian(self, lats, use_dask):
"""Test geographic areas over the antimeridian."""
import dask
from pyresample.test.utils import CustomScheduler
area = geometry.DynamicAreaDefinition('test_area', 'A test area',
'EPSG:4326')
lons = np.linspace(175, 185, 10)
lons[lons > 180] -= 360
result = area.freeze((lons, lats),
resolution=0.0056)

is_pole = (np.abs(lats) > 88).any()
if use_dask:
# if we aren't at a pole then we adjust the coordinates
# that takes a total of 2 computations
num_computes = 1 if is_pole else 2
lons = da.from_array(lons)
lats = da.from_array(lats)
with dask.config.set(scheduler=CustomScheduler(num_computes)):
result = area.freeze((lons, lats),
resolution=0.0056)
else:
result = area.freeze((lons, lats),
resolution=0.0056)

extent = result.area_extent
if is_pole:
assert extent[0] < -178
Expand Down
47 changes: 46 additions & 1 deletion pyresample/utils/proj4.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import math
from collections import OrderedDict

from pyproj import CRS
import numpy as np
from pyproj import CRS, Transformer as PROJTransformer


def convert_proj_floats(proj_pairs):
Expand Down Expand Up @@ -91,3 +92,47 @@ def get_geostationary_height(geos_area_crs):
params = geos_area_crs.coordinate_operation.params
h_param = [p for p in params if 'satellite height' in p.name.lower()][0]
return h_param.value


def _transform_dask_chunk(x, y, crs_from, crs_to, kwargs, transform_kwargs):
    """Transform one block of coordinates and stack the results.

    Args:
        x: First coordinate block passed to the pyproj transformer.
        y: Second coordinate block passed to the pyproj transformer.
        crs_from: Source CRS description, passed to ``CRS``.
        crs_to: Destination CRS description, passed to ``CRS``.
        kwargs: Keyword arguments for ``Transformer.from_crs``.
        transform_kwargs: Keyword arguments for ``Transformer.transform``.

    Returns:
        The transformed coordinates stacked along a new last axis.

    """
    source_crs = CRS(crs_from)
    target_crs = CRS(crs_to)
    chunk_transformer = PROJTransformer.from_crs(
        source_crs, target_crs, **kwargs)
    transformed = chunk_transformer.transform(x, y, **transform_kwargs)
    return np.stack(transformed, axis=-1)


class DaskFriendlyTransformer:
    """Wrapper around the pyproj Transformer class that uses dask."""

    def __init__(self, src_crs, dst_crs, **kwargs):
        """Initialize the transformer with CRS objects.

        This method should not be used directly, just like pyproj.Transformer
        should not be created directly.

        Args:
            src_crs: Source CRS, stored as given.
            dst_crs: Destination CRS, stored as given.
            **kwargs: Keyword arguments later forwarded to
                ``pyproj.Transformer.from_crs`` for every chunk.

        """
        self.src_crs = src_crs
        self.dst_crs = dst_crs
        self.kwargs = kwargs

    @classmethod
    def from_crs(cls, crs_from, crs_to, **kwargs):
        """Create transformer object from two CRS objects.

        Args:
            crs_from: Source CRS.
            crs_to: Destination CRS.
            **kwargs: Keyword arguments stored on the new object.

        Returns:
            A new instance of ``cls``.

        """
        return cls(crs_from, crs_to, **kwargs)

    def transform(self, x, y, **kwargs):
        """Transform coordinates.

        Args:
            x: Dask array of first coordinates; its ``dtype``, ``chunks``
                and ``ndim`` describe the lazily computed result.
            y: Second coordinates, mapped block-wise together with ``x``.
            **kwargs: Keyword arguments forwarded to
                ``pyproj.Transformer.transform`` for every chunk.

        Returns:
            Tuple ``(x, y)`` of lazily transformed dask arrays.

        """
        import dask.array as da
        # Accept anything pyproj can interpret as a CRS (CRS objects are
        # returned as-is) instead of requiring objects with ``to_wkt``.
        crs_from = CRS.from_user_input(self.src_crs)
        crs_to = CRS.from_user_input(self.dst_crs)
        # CRS objects aren't thread-safe until pyproj 3.1+
        # convert to WKT strings to be safe
        result = da.map_blocks(_transform_dask_chunk, x, y,
                               crs_from.to_wkt(), crs_to.to_wkt(),
                               dtype=x.dtype, chunks=x.chunks + ((2,),),
                               kwargs=self.kwargs,
                               transform_kwargs=kwargs,
                               new_axis=x.ndim)
        # The helper stacks both outputs on a new last axis; split them.
        x = result[..., 0]
        y = result[..., 1]
        return x, y

0 comments on commit f71a474

Please sign in to comment.