From a3552e53e4731c104674dda7826844b9eca1e3f9 Mon Sep 17 00:00:00 2001 From: naschmitz Date: Tue, 9 Apr 2024 14:24:54 -0700 Subject: [PATCH] Fix support for computed images by removing a hack to do fast slicing. Add a new `fast_time_slicing` parameter. If True, Xee performs an optimization that makes slicing an ImageCollection across time faster. This optimization loads EE images in a slice by ID, so any modifications to images in a computed ImageCollection will not be reflected. For those familiar with the code before, the else flow in `_slice_collection` was only entered when images in the collection didn't have IDs. Clearing the image IDs triggered the else block. Also adds several new warnings: - if a user enables `fast_time_slicing` but there are no image IDs, and - if a user is indexing into a very large ImageCollection. Fixes #88 and #145. PiperOrigin-RevId: 623280839 --- xee/ext.py | 39 ++++++++++++++++++++++++++++++++----- xee/ext_integration_test.py | 32 ++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/xee/ext.py b/xee/ext.py index 9ca80f5..6a87097 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -22,6 +22,7 @@ import functools import importlib import itertools +import logging import math import os import sys @@ -72,6 +73,12 @@ # trial & error. REQUEST_BYTE_LIMIT = 2**20 * 48 # 48 MBs +# Xee uses the ee.ImageCollection.toList function for slicing into an +# ImageCollection. This function isn't optimized for large collections. If the +# end index of the slice is beyond 10k, display a warning to the user. This +# value was chosen by trial and error. +_TO_LIST_WARNING_LIMIT = 10000 + def _check_request_limit(chunks: Dict[str, int], dtype_size: int, limit: int): """Checks that the actual number of bytes exceeds the limit.""" @@ -153,6 +160,7 @@ def open( ee_init_if_necessary: bool = False, executor_kwargs: Optional[Dict[str, Any]] = None, getitem_kwargs: Optional[Dict[str, int]] = None, + fast_time_slicing: bool = False, ) -> 'EarthEngineStore': if mode != 'r': raise ValueError( @@ -175,6 +183,7 @@ def open( ee_init_if_necessary=ee_init_if_necessary, executor_kwargs=executor_kwargs, getitem_kwargs=getitem_kwargs, + fast_time_slicing=fast_time_slicing, ) def __init__( @@ -194,9 +203,11 @@ def __init__( ee_init_if_necessary: bool = False, executor_kwargs: Optional[Dict[str, Any]] = None, getitem_kwargs: Optional[Dict[str, int]] = None, + fast_time_slicing: bool = False, ): self.ee_init_kwargs = ee_init_kwargs self.ee_init_if_necessary = ee_init_if_necessary + self.fast_time_slicing = fast_time_slicing # Initialize executor_kwargs if executor_kwargs is None: @@ -834,15 +845,27 @@ def _slice_collection(self, image_slice: slice) -> ee.Image: self._ee_init_check() start, stop, stride = image_slice.indices(self.shape[0]) - # If the input images have IDs, just slice them. Otherwise, we need to do - # an expensive `toList()` operation. - if self.store.image_ids: + if self.store.fast_time_slicing and self.store.image_ids: imgs = self.store.image_ids[start:stop:stride] else: + if self.store.fast_time_slicing: + logging.warning( + "fast_time_slicing is enabled but ImageCollection images don't have" + ' IDs. Reverting to default behavior.' + ) + if stop > _TO_LIST_WARNING_LIMIT: + logging.warning( + 'Xee is indexing into the ImageCollection beyond %s images. This' + ' operation can be slow. To improve performance, consider filtering' + ' the ImageCollection prior to using Xee or enabling' + ' fast_time_slicing.', + _TO_LIST_WARNING_LIMIT, + ) # TODO(alxr, mahrsee): Find a way to make this case more efficient. list_range = stop - start - col0 = self.store.image_collection - imgs = col0.toList(list_range, offset=start).slice(0, list_range, stride) + imgs = self.store.image_collection.toList(list_range, offset=start).slice( + 0, list_range, stride + ) col = ee.ImageCollection(imgs) @@ -1006,6 +1029,7 @@ def open_dataset( ee_init_kwargs: Optional[Dict[str, Any]] = None, executor_kwargs: Optional[Dict[str, Any]] = None, getitem_kwargs: Optional[Dict[str, int]] = None, + fast_time_slicing: bool = False, ) -> xarray.Dataset: # type: ignore """Open an Earth Engine ImageCollection as an Xarray Dataset. @@ -1084,6 +1108,10 @@ def open_dataset( - 'max_retries', the maximum number of retry attempts. Defaults to 6. - 'initial_delay', the initial delay in milliseconds before the first retry. Defaults to 500. + fast_time_slicing (optional): Whether to perform an optimization that + makes slicing an ImageCollection across time faster. This optimization + loads EE images in a slice by ID, so any modifications to images in a + computed ImageCollection will not be reflected. Returns: An xarray.Dataset that streams in remote data from Earth Engine. """ @@ -1114,6 +1142,7 @@ def open_dataset( ee_init_if_necessary=ee_init_if_necessary, executor_kwargs=executor_kwargs, getitem_kwargs=getitem_kwargs, + fast_time_slicing=fast_time_slicing, ) store_entrypoint = backends_store.StoreBackendEntrypoint() diff --git a/xee/ext_integration_test.py b/xee/ext_integration_test.py index e2d7466..143da6c 100644 --- a/xee/ext_integration_test.py +++ b/xee/ext_integration_test.py @@ -514,6 +514,38 @@ def test_validate_band_attrs(self): for _, value in variable.attrs.items(): self.assertIsInstance(value, valid_types) + def test_fast_time_slicing(self): + band = 'temperature_2m' + hourly = ( + ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY') + .filterDate('2024-01-01', '2024-01-02') + .select(band) + ) + first = hourly.first() + props = ['system:id', 'system:time_start'] + fake_collection = ee.ImageCollection( + hourly.toList(count=hourly.size()).replace( + first, ee.Image(0).rename(band).copyProperties(first, props) + ) + ) + + params = dict( + filename_or_obj=fake_collection, + engine=xee.EarthEngineBackendEntrypoint, + geometry=ee.Geometry.BBox(-83.86, 41.13, -76.83, 46.15), + projection=first.projection().atScale(100000), + ) + + # With slow slicing, the returned data should include the modified image. + slow_slicing = xr.open_dataset(**params) + slow_slicing_data = getattr(slow_slicing[dict(time=0)], band).as_numpy() + self.assertTrue(np.all(slow_slicing_data == 0)) + + # With fast slicing, the returned data should include the original image. + fast_slicing = xr.open_dataset(**params, fast_time_slicing=True) + fast_slicing_data = getattr(fast_slicing[dict(time=0)], band).as_numpy() + self.assertTrue(np.all(fast_slicing_data > 0)) + @absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded') def test_write_projected_dataset_to_raster(self): # ensure that a projected dataset written to a raster intersects with the