Skip to content

Commit

Permalink
Fix geogrid chunking to accept auto
Browse files Browse the repository at this point in the history
  • Loading branch information
mraspaud committed Nov 15, 2023
1 parent 11e0519 commit 13cd934
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 23 deletions.
44 changes: 21 additions & 23 deletions geotiepoints/interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,6 @@ def _interp(self):
if np.array_equal(self.hrow_indices, self.row_indices):
return self._interp1d()

xpoints, ypoints = np.meshgrid(self.hrow_indices,
self.hcol_indices)

for num, data in enumerate(self.tie_data):
spl = RectBivariateSpline(self.row_indices,
self.col_indices,
Expand All @@ -221,8 +218,7 @@ def _interp(self):
kx=self.kx_,
ky=self.ky_)

new_data_ = spl.ev(xpoints.ravel(), ypoints.ravel())
self.new_data[num] = new_data_.reshape(xpoints.shape).T.copy(order='C')
self.new_data[num] = spl(self.hrow_indices, self.hcol_indices, grid=True)

def _interp1d(self):
"""Interpolate in one dimension."""
Expand Down Expand Up @@ -279,38 +275,40 @@ def interpolate_dask(self, fine_points, method, chunks):
"""Interpolate (lazily) to a dask array."""
from dask.base import tokenize
import dask.array as da
from dask.array.core import normalize_chunks
v_fine_points, h_fine_points = fine_points
shape = len(v_fine_points), len(h_fine_points)

try:
v_chunk_size, h_chunk_size = chunks
except TypeError:
v_chunk_size, h_chunk_size = chunks, chunks

vchunks = range(0, shape[0], v_chunk_size)
hchunks = range(0, shape[1], h_chunk_size)
chunks = normalize_chunks(chunks, shape, dtype=self.values.dtype)

token = tokenize(v_chunk_size, h_chunk_size, self.points, self.values, fine_points, method)
token = tokenize(chunks, self.points, self.values, fine_points, method)
name = 'interpolate-' + token

dskx = {(name, i, j): (self.interpolate_slices,
(slice(vcs, min(vcs + v_chunk_size, shape[0])),
slice(hcs, min(hcs + h_chunk_size, shape[1]))),
method
)
for i, vcs in enumerate(vchunks)
for j, hcs in enumerate(hchunks)
}
def _enumerate_chunk_slices(chunks):
"""Enumerate chunks with slices."""
for position in np.ndindex(tuple(map(len, (chunks)))):
slices = []
for pos, chunk in zip(position, chunks):
chunk_size = chunk[pos]
offset = sum(chunk[:pos])
slices.append(slice(offset, offset + chunk_size))

yield (position, slices)

dskx = {(name, ) + position: (self.interpolate_slices,
slices,
method)
for position, slices in _enumerate_chunk_slices(chunks)}

res = da.Array(dskx, name, shape=list(shape),
chunks=(v_chunk_size, h_chunk_size),
chunks=chunks,
dtype=self.values.dtype)
return res

def interpolate_numpy(self, fine_points, method="linear"):
"""Interpolate to a numpy array."""
fine_x, fine_y = np.meshgrid(*fine_points, indexing='ij')
return self.interpolator((fine_x, fine_y), method=method)
return self.interpolator((fine_x, fine_y), method=method).astype(self.values.dtype)

def interpolate_slices(self, fine_points, method="linear"):
"""Interpolate using slices.
Expand Down
34 changes: 34 additions & 0 deletions geotiepoints/tests/test_geointerpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,40 @@ def test_geogrid_interpolation_to_shape(self):
np.testing.assert_allclose(lons[0, :], lons_expected, rtol=5e-5)
np.testing.assert_allclose(lats[:, 0], lats_expected, rtol=5e-5)

def test_geogrid_interpolation_preserves_dtype(self):
"""Test that the interpolated longitudes and latitudes keep the dtype of the tie-point data."""
x_points = np.array([0, 1, 3, 7])
y_points = np.array([0, 1, 3, 7, 15])

interpolator = GeoGridInterpolator((y_points, x_points),
TIE_LONS.astype(np.float32), TIE_LATS.astype(np.float32))

lons, lats = interpolator.interpolate_to_shape((16, 8))

assert lons.dtype == np.float32
assert lats.dtype == np.float32

def test_chunked_geogrid_interpolation(self):
"""Test that chunked interpolation produces the expected chunks for an integer and for "auto" chunk sizes."""
import dask

x_points = np.array([0, 1, 3, 7])
y_points = np.array([0, 1, 3, 7, 15])

interpolator = GeoGridInterpolator((y_points, x_points),
TIE_LONS.astype(np.float32), TIE_LATS.astype(np.float32))

lons, lats = interpolator.interpolate_to_shape((16, 8), chunks=4)

assert lons.chunks == ((4, 4, 4, 4), (4, 4))
assert lats.chunks == ((4, 4, 4, 4), (4, 4))

with dask.config.set({"array.chunk-size": 64}):

lons, lats = interpolator.interpolate_to_shape((16, 8), chunks="auto")
assert lons.chunks == ((4, 4, 4, 4), (4, 4))
assert lats.chunks == ((4, 4, 4, 4), (4, 4))

def test_geogrid_interpolation_can_extrapolate(self):
"""Test that the interpolator can also extrapolate given the right parameters."""
x_points = np.array([0, 1, 3, 7])
Expand Down

0 comments on commit 13cd934

Please sign in to comment.