How to add dask support
See the steps below for how dask support can be added to WRF-Python in v2.0.
Let's choose a simple example, the 'tv' diagnostic from getvar, in particular the 'getter' function found in g_temp.py. The code is below (note: ignoring the metadata and units stuff here):
def get_tv(wrfin, timeidx=0, method="cat", squeeze=True,
           cache=None, meta=True, _key=None,
           units="K"):
    varnames = ("T", "P", "PB", "QVAPOR")
    ncvars = extract_vars(wrfin, timeidx, varnames, method, squeeze, cache,
                          meta=False, _key=_key)
    t = ncvars["T"]
    p = ncvars["P"]
    pb = ncvars["PB"]
    qv = ncvars["QVAPOR"]

    full_t = t + Constants.T_BASE
    full_p = p + pb
    tk = _tk(full_p, full_t)
    tv = _tv(tk, qv)

    return tv
In the code above, we extract a few variables from a WRF file, compute the full pressure and full potential temperature (base plus perturbation), then compute tk (temperature in kelvin) and tv (virtual temperature). Below, let's show how this could be rewritten using xarray and dask, which can then be easily wrapped into an xarray extension.
-
Create thin wrappers
The _tv and _tk code in extension.py calls a Fortran routine and performs several common operations via wrapt decorators. Unfortunately, wrapt decorators don't serialize, so we need to create thin wrappers around them.
Also, in order to create the dask tasks, we need functions to pass to the dask map_blocks routine, so we'll also need wrappers around the "base + perturbation" operations above.
Let's start with the _tv and _tk wrappers. For these functions, since OpenMP is already supported at the Fortran level, let's take an additional argument to set the number of OpenMP threads to use. Note that dask can do what OpenMP does, so this is entirely optional, but if you wanted to use dask tasks with OpenMP for the low level computation, this is one way to do it.
def tk_wrap(pressure, theta, omp_threads=1):
    from wrf.extension import _tk, omp_set_num_threads

    omp_set_num_threads(omp_threads)
    result = _tk(pressure, theta)

    return result

def tv_wrap(temp_k, qvapor, omp_threads=1):
    from wrf.extension import _tv, omp_set_num_threads

    omp_set_num_threads(omp_threads)
    result = _tv(temp_k, qvapor)

    return result
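If you do combine dask workers with OpenMP threads, the product of the two should generally not exceed the number of available cores. The sketch below shows one way to pick a value for omp_threads. The helper name omp_threads_per_worker and the integer-division heuristic are assumptions made for illustration. They are not part of WRF-Python or dask.

```python
import os


def omp_threads_per_worker(n_workers, total_cores=None):
    """Return an OpenMP thread count that keeps workers * threads <= cores.

    Hypothetical helper for illustration; not part of WRF-Python or dask.
    """
    if n_workers < 1:
        raise ValueError("n_workers must be at least 1")
    if total_cores is None:
        # os.cpu_count() can return None on some platforms; fall back to 1.
        total_cores = os.cpu_count() or 1
    # Never return fewer than one thread, even with more workers than cores.
    return max(1, total_cores // n_workers)


# Example: 4 dask workers on a 16-core machine leaves 4 OpenMP threads each.
threads = omp_threads_per_worker(n_workers=4, total_cores=16)
```

The resulting value could then be passed as the omp_threads argument of tk_wrap and tv_wrap.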
Next, let's make a thin wrapper for the "perturbation + base" operation:
def pert_add(base, perturbation):
    return base + perturbation
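To try out the dask plumbing on a machine without the compiled Fortran extension, pure-Python stand-ins for the two wrappers can be handy. The sketch below uses the textbook Poisson relation for temperature and the usual mixing-ratio form of virtual temperature. The names tk_standin and tv_standin are hypothetical, and the constants are assumed textbook values that may differ slightly from those in WRF-Python, so results are not guaranteed to match the Fortran routines exactly.

```python
# Assumed constants (textbook dry-air values); WRF-Python's own
# Constants class may use slightly different numbers.
P1000MB = 100000.0   # reference pressure, Pa
RD = 287.0           # gas constant for dry air, J kg-1 K-1
CP = 7.0 * RD / 2.0  # specific heat at constant pressure, J kg-1 K-1
EPS = 0.622          # ratio of gas constants, dry air / water vapor


def tk_standin(pressure, theta):
    """Temperature (K) from pressure (Pa) and potential temperature (K)."""
    return theta * (pressure / P1000MB) ** (RD / CP)


def tv_standin(temp_k, qvapor):
    """Virtual temperature (K) from temperature (K) and mixing ratio (kg/kg)."""
    return temp_k * (EPS + qvapor) / (EPS * (1.0 + qvapor))
```

Both functions use only elementwise arithmetic, so they accept plain floats as well as array blocks.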
-
Create a 'getter' function for the 'tv' diagnostic
The original getter function above takes several common arguments (wrfin, timeidx, method, squeeze, etc.). Since much of WRF-Python 1.x implements things xarray already does, a lot of this can be gutted in WRF-Python 2.x in favor of xarray. So, the new getter function only needs to take an xarray Dataset argument and an argument for the number of OpenMP threads (only if you want users to control OpenMP threads).
For this example, we're going to assume dask is installed and the Dataset is backed by dask arrays. In a real implementation, the xarray.DataArray.data attribute might return a numpy array instead, so the getter should be prepared for that. Here, we're only trying to show how to make dask work.
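One possible way to prepare for numpy-backed data is to route every call through a small dispatcher that only uses map_blocks when at least one input is a dask array. This is a minimal sketch. The helper name apply_blockwise is hypothetical, and treating dask as an optional import is an assumption about how a real implementation might be packaged.

```python
try:
    import dask.array as da
except ImportError:  # dask is treated as optional in this sketch
    da = None


def apply_blockwise(func, *args, dtype=None, **kwargs):
    """Call func lazily via map_blocks if any argument is a dask array.

    Hypothetical helper; with non-dask inputs, func runs eagerly instead.
    """
    if da is not None and any(isinstance(a, da.Array) for a in args):
        return da.map_blocks(func, *args, dtype=dtype, **kwargs)
    return func(*args, **kwargs)


def pert_add(base, perturbation):
    return base + perturbation


# With plain (non-dask) inputs the function is simply called directly.
eager = apply_blockwise(pert_add, 300.0, 1.5)
```

With dask-backed inputs the same call returns a lazy dask array, so the getter body would not need separate code paths.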
Here is the code:
from wrf import Constants
from dask.array import map_blocks

def tv_getter(ds, omp_threads=1):
    t = ds["T"].data
    p = ds["P"].data
    pb = ds["PB"].data
    qv = ds["QVAPOR"].data

    # pert_add takes only two arguments, so omp_threads is not passed here
    full_t = map_blocks(pert_add, Constants.T_BASE, t, dtype=t.dtype)
    full_p = map_blocks(pert_add, pb, p, dtype=p.dtype)

    # The Fortran wrappers take omp_threads as their third argument
    tk = map_blocks(tk_wrap, full_p, full_t, omp_threads, dtype=p.dtype)
    tv = map_blocks(tv_wrap, tk, qv, omp_threads, dtype=p.dtype)

    return tv
Note that the above returns a dask array with no metadata, so nothing has actually happened other than building the task graph for dask. If you want to actually compute something, you need to call compute() on the returned dask object. Adding metadata is beyond the scope of this tutorial and is left as an exercise to whoever implements WRF-Python 2.x.
Let's test the getter function below (assuming your getter function has already been imported):
import xarray

# Open the WRF output files as a single Dataset
ds = xarray.open_mfdataset("/path/to/wrf_vortex_multi/moving_nest/wrfout_d02*", parallel=True)

tv = tv_getter(ds, omp_threads=1)

# Now actually compute tv (note: result is a numpy array)
tv_result = tv.compute()
-
Create an xarray extension if you want an object oriented API
If you want your API to work on the Dataset object itself, rather than through a separate function, xarray provides an easy way to add your own extensions, which is described here: http://xarray.pydata.org/en/stable/internals.html#extending-xarray
Since most of the WRF routines work on Datasets, we're going to create a Dataset extension. Again, the result of this will be a dask array, so handling when the actual computation is performed and metadata applied will be left as an exercise to whoever implements WRF-Python 2.x, but this should illustrate the basic concepts.
First, let's create the xarray extension class, which will add a new 'wrf' attribute to the Dataset.
import xarray

_FUNC_MAP = {'tv': tv_getter}

@xarray.register_dataset_accessor('wrf')
class WRFDatasetExtension(object):
    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def getvar(self, product, omp_threads=1, **kwargs):
        return _FUNC_MAP[product](self._obj, omp_threads, **kwargs)
Now if you want to use this:
ds = xarray.open_mfdataset("/path/to/wrf_vortex_multi/moving_nest/wrfout_d02*", parallel=True)

tv = ds.wrf.getvar("tv", omp_threads=1)

# Compute the result
result = tv.compute()
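As more diagnostics are added, maintaining the _FUNC_MAP dictionary by hand could become error-prone. One possible design is a small registration decorator plus a lookup that reports the available products when an unknown name is requested. The names register_product and lookup_getter are hypothetical, and the tv_getter body here is only a placeholder so the sketch runs on its own.

```python
_FUNC_MAP = {}


def register_product(name):
    """Decorator that adds a getter to _FUNC_MAP under the given name."""
    def decorator(func):
        _FUNC_MAP[name] = func
        return func
    return decorator


def lookup_getter(product):
    """Return the getter for a product, or raise a readable error."""
    try:
        return _FUNC_MAP[product]
    except KeyError:
        available = ", ".join(sorted(_FUNC_MAP)) or "none"
        raise ValueError(
            f"Unknown product {product!r}; available: {available}"
        ) from None


@register_product("tv")
def tv_getter(ds, omp_threads=1):
    # Placeholder body for illustration only; the real getter would
    # build the dask task graph as shown earlier in this page.
    return ("tv", ds, omp_threads)
```

Inside the accessor, getvar could then call lookup_getter(product)(self._obj, omp_threads, **kwargs) instead of indexing the dictionary directly.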