Skip to content

Commit

Permalink
Optimize loading behavior for GetMap (#110)
Browse files Browse the repository at this point in the history
* Change compute to load

* Add configuragle array limit

* Make fully configurable

* Thread through
  • Loading branch information
mpiannucci authored Jan 7, 2025
1 parent 112696a commit ae30539
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 14 deletions.
8 changes: 4 additions & 4 deletions xpublish_wms/grids/hycom.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def project(
da = self.mask(da)

# create 2 separate DataArrays where points lng>180 are put at the beginning of the array
mask_0 = xr.where(da.cf["longitude"] <= 180, 1, 0)
temp_da_0 = da.where(mask_0.compute() == 1, drop=True)
mask_0 = xr.where(da.cf["longitude"] <= 180, 1, 0).load()
temp_da_0 = da.where(mask_0 == 1, drop=True)
da_0 = xr.DataArray(
data=temp_da_0,
dims=temp_da_0.dims,
Expand All @@ -91,8 +91,8 @@ def project(
attrs=temp_da_0.attrs,
)

mask_1 = xr.where(da.cf["longitude"] > 180, 1, 0)
temp_da_1 = da.where(mask_1.compute() == 1, drop=True)
mask_1 = xr.where(da.cf["longitude"] > 180, 1, 0).load()
temp_da_1 = da.where(mask_1 == 1, drop=True)
temp_da_1.cf["longitude"][:] = temp_da_1.cf["longitude"][:] - 360
da_1 = xr.DataArray(
data=temp_da_1,
Expand Down
14 changes: 13 additions & 1 deletion xpublish_wms/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ class CfWmsPlugin(Plugin):
dataset_router_prefix: str = "/wms"
dataset_router_tags: List[str] = ["wms"]

# Limit for rendering arrays in get_map after subsetting to the requested
# bounding box. If the array is larger than this threshold, an error will be thrown.
# Default is 1e9 bytes (1 GB)
array_get_map_render_threshold_bytes: int = 1e9

@hookimpl
def dataset_router(self, deps: Dependencies) -> APIRouter:
"""Register dataset level router for WMS endpoints"""
Expand All @@ -54,6 +59,13 @@ def wms_root(
dataset: xr.Dataset = Depends(deps.dataset),
cache: cachey.Cache = Depends(deps.cache),
):
return wms_handler(request, wms_query, dataset, cache)
# TODO: Make threshold configurable
return wms_handler(
request,
wms_query,
dataset,
cache,
array_get_map_render_threshold_bytes=self.array_get_map_render_threshold_bytes,
)

return router
14 changes: 12 additions & 2 deletions xpublish_wms/wms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def wms_handler(
],
dataset: xr.Dataset,
cache: cachey.Cache,
array_get_map_render_threshold_bytes: int,
) -> Response:
query_params = lower_case_keys(request.query_params)
query_keys = list(query_params.keys())
Expand All @@ -52,9 +53,18 @@ def wms_handler(
if isinstance(query, WMSGetCapabilitiesQuery):
return get_capabilities(dataset, request, query)
elif isinstance(query, WMSGetMetadataQuery):
return get_metadata(dataset, cache, query, query_params)
return get_metadata(
dataset,
cache,
query,
query_params,
array_get_map_render_threshold_bytes=array_get_map_render_threshold_bytes,
)
elif isinstance(query, WMSGetMapQuery):
getmap_service = GetMap(cache=cache)
getmap_service = GetMap(
cache=cache,
array_render_threshold_bytes=array_get_map_render_threshold_bytes,
)
return getmap_service.get_map(dataset, query, query_params)
elif isinstance(query, WMSGetFeatureInfoQuery):
return get_feature_info(dataset, query, query_params)
Expand Down
25 changes: 20 additions & 5 deletions xpublish_wms/wms/get_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class GetMap:
BBOX_BUFFER = 0.18

cache: cachey.Cache
array_render_threshold_bytes: int

# Data selection
parameter: str
Expand All @@ -56,8 +57,9 @@ class GetMap:
colorscalerange: List[float]
autoscale: bool

def __init__(self, cache: cachey.Cache):
def __init__(self, cache: cachey.Cache, array_render_threshold_bytes: int):
self.cache = cache
self.array_render_threshold_bytes = array_render_threshold_bytes

def get_map(
self,
Expand Down Expand Up @@ -327,14 +329,27 @@ def render(

logger.info(f"WMS GetMap Projection time: {time.time() - projection_start}")

# Print the size of the da in megabytes
da_size = da.nbytes
if da_size > self.array_render_threshold_bytes:
logger.error(
f"DataArray size is {da_size:.2f} bytes, which is larger than the "
f"threshold of {self.array_render_threshold_bytes} bytes. "
f"Consider increasing the threshold in the plugin configuration.",
)
raise ValueError(
f"DataArray too large to render: threshold is {self.array_render_threshold_bytes} bytes, data is {da_size:.2f} bytes",
)
logger.info(f"WMS GetMap loading DataArray size: {da_size:.2f} bytes")

start_dask = time.time()

da = da.compute()
da = da.load()
if x is not None and y is not None:
x = x.compute()
y = y.compute()
x = x.load()
y = y.load()

logger.info(f"WMS GetMap dask compute: {time.time() - start_dask}")
logger.info(f"WMS GetMap load data: {time.time() - start_dask}")

if minmax_only:
try:
Expand Down
15 changes: 13 additions & 2 deletions xpublish_wms/wms/get_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def get_metadata(
cache: cachey.Cache,
query: WMSGetMetadataQuery,
query_params: dict,
array_get_map_render_threshold_bytes: int,
) -> Response:
"""
Return the WMS metadata for the dataset
Expand Down Expand Up @@ -45,7 +46,13 @@ def get_metadata(
da = ds[layer_name]
payload = get_timesteps(da, query)
elif metadata_type == "minmax":
payload = get_minmax(ds, cache, query, query_params)
payload = get_minmax(
ds,
cache,
query,
query_params,
array_get_map_render_threshold_bytes,
)
else:
raise HTTPException(
status_code=400,
Expand Down Expand Up @@ -90,6 +97,7 @@ def get_minmax(
cache: cachey.Cache,
query: WMSGetMetadataQuery,
query_params: dict,
array_get_map_render_threshold_bytes: int,
) -> dict:
"""
Returns the min and max range of values for a given layer in a given area
Expand All @@ -112,7 +120,10 @@ def get_minmax(
colorscalerange="nan,nan",
)

getmap = GetMap(cache=cache)
getmap = GetMap(
cache=cache,
array_render_threshold_bytes=array_get_map_render_threshold_bytes,
)
return getmap.get_minmax(ds, getmap_query, query_params, entire_layer)


Expand Down

0 comments on commit ae30539

Please sign in to comment.