Skip to content

Commit

Permalink
support @lru_cache on load_collection with DriverVectorCube in LoadPa…
Browse files Browse the repository at this point in the history
  • Loading branch information
bossie committed Sep 23, 2022
1 parent 69b98f0 commit dd05a21
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 43 deletions.
5 changes: 2 additions & 3 deletions openeo_driver/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from openeo_driver.users import User
from openeo_driver.users.oidc import OidcProvider
from openeo_driver.utils import read_json, dict_item, EvalEnv, extract_namedtuple_fields_from_dict, \
get_package_versions, EvalEnvEncoder
get_package_versions

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -165,8 +165,7 @@ def copy(self) -> "LoadParameters":
return LoadParameters(super().copy())

def __hash__(self) -> int:
return hash(json.dumps(self, sort_keys=True,cls=EvalEnvEncoder))

return 0 # poorly hashable but load_collection's lru_cache is small anyway


class AbstractCollectionCatalog(MicroService, metaclass=abc.ABCMeta):
Expand Down
5 changes: 5 additions & 0 deletions openeo_driver/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List, Union, Optional, Dict, Any, Tuple, Sequence

import geopandas as gpd
import numpy as np
import pyproj
import shapely.geometry
import shapely.geometry.base
Expand Down Expand Up @@ -288,6 +289,10 @@ def get_xarray_cube_basics(self) -> Tuple[tuple, dict]:
coords = {self.DIM_GEOMETRIES: self._geometries.index.to_list()}
return dims, coords

def __eq__(self, other):
return (isinstance(other, DriverVectorCube)
and np.array_equal(self._as_geopandas_df().values, other._as_geopandas_df().values))


class DriverMlModel:
"""Base class for driver-side 'ml-model' data structures"""
Expand Down
44 changes: 4 additions & 40 deletions openeo_driver/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,45 +26,6 @@
_log = logging.getLogger(__name__)


class EvalEnvEncoder(JSONEncoder):
"""
A custom json encoder in support of the __hash__ function. Does not aim to provide a completely representative json encoding.
"""
def default(self, o):
try:
iterable = iter(o)
except TypeError:
pass
else:
return list(iterable)

if isinstance(o,BaseGeometry):
return mapping(o)

from openeo_driver.backend import OpenEoBackendImplementation
from openeo_driver.dry_run import DryRunDataTracer
if isinstance(o,OpenEoBackendImplementation) or isinstance(o,DryRunDataTracer):
return str(o.__class__.__name__)

from openeo_driver.datacube import DriverDataCube
if isinstance(o,DriverDataCube):
return str(o)

from openeo_driver.users import User
if isinstance(o,User):
return o.user_id

if isinstance(o, Enum):
return o.value

from openeo_driver.delayed_vector import DelayedVector
if isinstance(o, DelayedVector):
return o.path

# Let the base class default method raise the TypeError
return JSONEncoder.default(self, o)


class EvalEnv:
"""
Process graph evaluation environment: key-value container for keeping track
Expand Down Expand Up @@ -131,7 +92,10 @@ def __str__(self):
return str(self.as_dict())

def __hash__(self) -> int:
return hash(json.dumps(self.as_dict(), sort_keys=True, cls=EvalEnvEncoder))
return 0 # poorly hashable but load_collection's lru_cache is small anyway

def __eq__(self, other) -> bool:
return isinstance(other, EvalEnv) and self.as_dict() == other.as_dict()

@property
def backend_implementation(self) -> 'OpenEoBackendImplementation':
Expand Down

0 comments on commit dd05a21

Please sign in to comment.