Commit c01ace2

Merge pull request #36 from brews/flake8_cleanup

Bug fixes and cleanup to pass flake8 in CI

brews authored Jun 28, 2022
2 parents ecdd7db + 6867555 commit c01ace2

Showing 9 changed files with 251 additions and 255 deletions.
4 changes: 2 additions & 2 deletions climate_toolbox/__init__.py

@@ -3,5 +3,5 @@
 """Top-level package for climate_toolbox."""

 __author__ = """Justin Simcock"""
-__email__ = '[email protected]'
-__version__ = '0.1.5'
+__email__ = "[email protected]"
+__version__ = "0.1.5"
76 changes: 26 additions & 50 deletions climate_toolbox/aggregations/aggregations.py

@@ -19,29 +19,22 @@ def _reindex_spatial_data_to_regions(ds, df):
     """

     # use vectorized indexing in xarray >= 0.10
-    if LooseVersion(xr.__version__) > LooseVersion('0.9.999'):
+    if LooseVersion(xr.__version__) > LooseVersion("0.9.999"):

-        lon_indexer = xr.DataArray(df.lon.values, dims=('reshape_index', ))
-        lat_indexer = xr.DataArray(df.lat.values, dims=('reshape_index', ))
+        lon_indexer = xr.DataArray(df.lon.values, dims=("reshape_index",))
+        lat_indexer = xr.DataArray(df.lat.values, dims=("reshape_index",))

         return ds.sel(lon=lon_indexer, lat=lat_indexer)

     else:
-        res = ds.sel_points(
-            'reshape_index',
-            lat=df.lat.values,
-            lon=df.lon.values)
+        res = ds.sel_points("reshape_index", lat=df.lat.values, lon=df.lon.values)

         return res


 def _aggregate_reindexed_data_to_regions(
-        ds,
-        variable,
-        aggwt,
-        agglev,
-        weights,
-        backup_aggwt='areawt'):
+    ds, variable, aggwt, agglev, weights, backup_aggwt="areawt"
+):
     """
     Performs weighted avg for climate variable by region

@@ -69,39 +62,29 @@ def _aggregate_reindexed_data_to_regions(
     """

     ds.coords[agglev] = xr.DataArray(
-        weights[agglev].values,
-        dims={'reshape_index': weights.index.values})
+        weights[agglev].values, dims={"reshape_index": weights.index.values}
+    )

     # format weights
     ds[aggwt] = xr.DataArray(
-        weights[aggwt].values,
-        dims={'reshape_index': weights.index.values})
-
-    ds[aggwt] = (
-        ds[aggwt]
-        .where(ds[aggwt] > 0)
-        .fillna(weights[backup_aggwt].values))
-
-    weighted = xr.Dataset({
-        variable: (
-            (
-                (ds[variable]*ds[aggwt])
-                .groupby(agglev)
-                .sum(dim='reshape_index')) /
-            (
-                ds[aggwt]
-                .groupby(agglev)
-                .sum(dim='reshape_index')))})
+        weights[aggwt].values, dims={"reshape_index": weights.index.values}
+    )
+
+    ds[aggwt] = ds[aggwt].where(ds[aggwt] > 0).fillna(weights[backup_aggwt].values)
+
+    weighted = xr.Dataset(
+        {
+            variable: (
+                ((ds[variable] * ds[aggwt]).groupby(agglev).sum(dim="reshape_index"))
+                / (ds[aggwt].groupby(agglev).sum(dim="reshape_index"))
+            )
+        }
+    )

     return weighted


-def weighted_aggregate_grid_to_regions(
-        ds,
-        variable,
-        aggwt,
-        agglev,
-        weights=None):
+def weighted_aggregate_grid_to_regions(ds, variable, aggwt, agglev, weights=None):
     """
     Computes the weighted reshape of gridded data

@@ -136,12 +119,7 @@ def weighted_aggregate_grid_to_regions(
         weights = prepare_spatial_weights_data()

     ds = _reindex_spatial_data_to_regions(ds, weights)
-    ds = _aggregate_reindexed_data_to_regions(
-        ds,
-        variable,
-        aggwt,
-        agglev,
-        weights)
+    ds = _aggregate_reindexed_data_to_regions(ds, variable, aggwt, agglev, weights)

     return ds

@@ -163,14 +141,12 @@ def prepare_spatial_weights_data(weights_file):
     df = pd.read_csv(weights_file)

     # Re-label out-of-bounds pixel centers
-    df.set_value((df['pix_cent_x'] == 180.125), 'pix_cent_x', -179.875)
+    df.set_value((df["pix_cent_x"] == 180.125), "pix_cent_x", -179.875)

     # probably totally unnecessary
     df.drop_duplicates()
-    df.index.names = ['reshape_index']
+    df.index.names = ["reshape_index"]

-    df.rename(
-        columns={'pix_cent_x': 'lon', 'pix_cent_y': 'lat'},
-        inplace=True)
+    df.rename(columns={"pix_cent_x": "lon", "pix_cent_y": "lat"}, inplace=True)

     return df
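
Note on the rewrapped aggregation hunk above: only the formatting changes; the computation is still a grouped weighted mean (the sum of variable times weight over the pixels in each region, divided by the sum of weights). A minimal sketch of that arithmetic in plain xarray, with made-up names and values ("region", "w", and the numbers are illustrative, not from the package):

    import xarray as xr

    # Four pixels assigned to two regions; "w" plays the role of aggwt.
    ds = xr.Dataset({"tas": ("reshape_index", [1.0, 2.0, 3.0, 4.0])})
    ds.coords["region"] = ("reshape_index", ["a", "a", "b", "b"])
    ds["w"] = ("reshape_index", [1.0, 3.0, 1.0, 1.0])

    # Grouped weighted mean, mirroring the xr.Dataset(...) expression above.
    num = (ds["tas"] * ds["w"]).groupby("region").sum(dim="reshape_index")
    den = ds["w"].groupby("region").sum(dim="reshape_index")
    weighted = xr.Dataset({"tas": num / den})

    print(weighted["tas"].values)  # [1.75 3.5]; region "a" is (1*1 + 2*3) / (1 + 3)
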
1 change: 0 additions & 1 deletion climate_toolbox/climate_toolbox.py

@@ -1,4 +1,3 @@
 """
 This file describes the process for computing weighted climate data
 """
-
27 changes: 16 additions & 11 deletions climate_toolbox/geo/distance.py

@@ -1,12 +1,14 @@
 import numpy as np

 # model               major (km)    minor (km)     flattening
-ELLIPSOIDS = {'WGS-84': (6378.137, 6356.7523142, 1 / 298.257223563),
-              'GRS-80': (6378.137, 6356.7523141, 1 / 298.257222101),
-              'Airy (1830)': (6377.563396, 6356.256909, 1 / 299.3249646),
-              'Intl 1924': (6378.388, 6356.911946, 1 / 297.0),
-              'Clarke (1880)': (6378.249145, 6356.51486955, 1 / 293.465),
-              'GRS-67': (6378.1600, 6356.774719, 1 / 298.25),
-              }
+ELLIPSOIDS = {
+    "WGS-84": (6378.137, 6356.7523142, 1 / 298.257223563),
+    "GRS-80": (6378.137, 6356.7523141, 1 / 298.257222101),
+    "Airy (1830)": (6377.563396, 6356.256909, 1 / 299.3249646),
+    "Intl 1924": (6378.388, 6356.911946, 1 / 297.0),
+    "Clarke (1880)": (6378.249145, 6356.51486955, 1 / 293.465),
+    "GRS-67": (6378.1600, 6356.774719, 1 / 298.25),
+}

+
 EARTH_RADIUS = 6371.009

@@ -57,9 +59,12 @@ def great_circle(ax, ay, bx, by, radius=EARTH_RADIUS):
     delta_lng = lng2 - lng1
     cos_delta_lng, sin_delta_lng = np.cos(delta_lng), np.sin(delta_lng)

-    d = np.arctan2(np.sqrt((cos_lat2 * sin_delta_lng) ** 2 +
-                           (cos_lat1 * sin_lat2 -
-                            sin_lat1 * cos_lat2 * cos_delta_lng) ** 2),
-                   sin_lat1 * sin_lat2 + cos_lat1 * cos_lat2 * cos_delta_lng)
+    d = np.arctan2(
+        np.sqrt(
+            (cos_lat2 * sin_delta_lng) ** 2
+            + (cos_lat1 * sin_lat2 - sin_lat1 * cos_lat2 * cos_delta_lng) ** 2
+        ),
+        sin_lat1 * sin_lat2 + cos_lat1 * cos_lat2 * cos_delta_lng,
+    )

     return radius * d
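
The rewrapped expression in great_circle is the spherical great-circle distance in its arctan2 form, which stays numerically stable for very small and near-antipodal separations, unlike the plain arccos formula. A self-contained sketch of the same formula follows; the hunk does not show how great_circle derives lat1/lng1 from its ax/ay arguments or whether it converts degrees to radians, so this version (function name and conversion are this sketch's assumptions) does the conversion explicitly:

    import numpy as np

    EARTH_RADIUS = 6371.009  # mean Earth radius in km, as in the module

    def great_circle_km(lat1_deg, lng1_deg, lat2_deg, lng2_deg):
        # Convert to radians, then apply the arctan2 form from the hunk above.
        lat1, lng1, lat2, lng2 = map(
            np.radians, (lat1_deg, lng1_deg, lat2_deg, lng2_deg)
        )
        sin_lat1, cos_lat1 = np.sin(lat1), np.cos(lat1)
        sin_lat2, cos_lat2 = np.sin(lat2), np.cos(lat2)
        delta_lng = lng2 - lng1
        cos_delta_lng, sin_delta_lng = np.cos(delta_lng), np.sin(delta_lng)
        d = np.arctan2(
            np.sqrt(
                (cos_lat2 * sin_delta_lng) ** 2
                + (cos_lat1 * sin_lat2 - sin_lat1 * cos_lat2 * cos_delta_lng) ** 2
            ),
            sin_lat1 * sin_lat2 + cos_lat1 * cos_lat2 * cos_delta_lng,
        )
        return EARTH_RADIUS * d

    # Paris to New York is roughly 5,840 km on a sphere of this radius.
    print(great_circle_km(48.8566, 2.3522, 40.7128, -74.0060))
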
17 changes: 6 additions & 11 deletions climate_toolbox/io/io.py

@@ -1,6 +1,6 @@
 import xarray as xr

-from climate_toolbox.utils.utils import *
+from climate_toolbox.utils.utils import rename_coords_to_lon_and_lat, convert_lons_split


 def standardize_climate_data(ds):
@@ -19,12 +19,12 @@ def standardize_climate_data(ds):
     """

     ds = rename_coords_to_lon_and_lat(ds)
-    ds = convert_lons_split(ds, lon_name='lon')
+    ds = convert_lons_split(ds, lon_name="lon")

     return ds


-def load_bcsd(fp, varname, lon_name='lon', broadcast_dims=('time',)):
+def load_bcsd(fp, varname, lon_name="lon", broadcast_dims=("time",)):
     """
     Read and prepare climate data
@@ -48,11 +48,7 @@ def load_bcsd(fp, varname, lon_name='lon', broadcast_dims=('time',)):
     xr.Dataset
         xarray dataset loaded into memory
     """
-
-    if lon_name is not None:
-        lon_names = [lon_name]
-
-    if hasattr(fp, 'sel_points'):
+    if hasattr(fp, "sel_points"):
         ds = fp

     else:
@@ -62,10 +58,9 @@ def load_bcsd(fp, varname, lon_name='lon', broadcast_dims=('time',)):
     return standardize_climate_data(ds)


-def load_gmfd(fp, varname, lon_name='lon', broadcast_dims=('time',)):
+def load_gmfd(fp, varname, lon_name="lon", broadcast_dims=("time",)):
     pass


-def load_best(fp, varname, lon_name='lon', broadcast_dims=('time',)):
+def load_best(fp, varname, lon_name="lon", broadcast_dims=("time",)):
     pass
-
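
Two of the fixes in this file matter beyond style: replacing the wildcard import with explicit names is what lets flake8 flag undefined names (checks F403/F405), and the deleted lon_names block assigned a variable that was never used (F841). The surviving hasattr(fp, "sel_points") test is duck typing: load_bcsd accepts either an already-open dataset or something it should open. A hedged sketch of that dispatch pattern; load_toy and the isinstance check are this sketch's stand-ins, since sel_points only exists on the old xarray versions this code targets:

    import xarray as xr

    def load_toy(fp):
        # Pass an open Dataset straight through; otherwise treat fp as a path.
        # isinstance stands in for the hasattr(fp, "sel_points") check above.
        if isinstance(fp, xr.Dataset):
            return fp
        return xr.open_dataset(fp)

    ds = xr.Dataset({"tas": ("time", [280.0, 281.5])})
    assert load_toy(ds) is ds  # already-open data is not re-opened
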
81 changes: 42 additions & 39 deletions climate_toolbox/transformations/transformations.py

@@ -1,8 +1,7 @@
 import xarray as xr
 import numpy as np

-from climate_toolbox.utils.utils import \
-    remove_leap_days, convert_kelvin_to_celsius
+from climate_toolbox.utils.utils import remove_leap_days, convert_kelvin_to_celsius


 def snyder_edd(tasmin, tasmax, threshold):
@@ -67,9 +66,9 @@ def snyder_edd(tasmin, tasmax, threshold):
     assert not (tasmax < tasmin).any(), "values encountered where tasmin > tasmax"

     # compute useful quantities for use in the transformation
-    snyder_mean = ((tasmax + tasmin)/2)
-    snyder_width = ((tasmax - tasmin)/2)
-    snyder_theta = xr.ufuncs.arcsin((threshold - snyder_mean)/snyder_width)
+    snyder_mean = (tasmax + tasmin) / 2
+    snyder_width = (tasmax - tasmin) / 2
+    snyder_theta = xr.ufuncs.arcsin((threshold - snyder_mean) / snyder_width)

     # the trasnformation is computed using numpy arrays, taking advantage of
     # numpy's second where clause. Note that in the current dev build of
@@ -79,13 +78,17 @@ def snyder_edd(tasmin, tasmax, threshold):
         tasmin < threshold,
         xr.where(
             tasmax > threshold,
-            ((snyder_mean - threshold) * (np.pi/2 - snyder_theta)
-                + (snyder_width * np.cos(snyder_theta))) / np.pi,
-            0),
-        snyder_mean - threshold)
-
-    res.attrs['units'] = (
-        'degreedays_{}{}'.format(threshold, tasmax.attrs['units']))
+            (
+                (snyder_mean - threshold) * (np.pi / 2 - snyder_theta)
+                + (snyder_width * np.cos(snyder_theta))
+            )
+            / np.pi,
+            0,
+        ),
+        snyder_mean - threshold,
+    )
+
+    res.attrs["units"] = "degreedays_{}{}".format(threshold, tasmax.attrs["units"])

     return res

@@ -133,19 +136,19 @@ def snyder_gdd(tasmin, tasmax, threshold_low, threshold_high):
     # Check for unit agreement
     assert tasmin.units == tasmax.units

-    res = (
-        snyder_edd(tasmin, tasmax, threshold_low)
-        - snyder_edd(tasmin, tasmax, threshold_high))
-
-    res.attrs['units'] = (
-        'degreedays_{}-{}{}'.format(threshold_low, threshold_high, tasmax.attrs['units']))
+    res = snyder_edd(tasmin, tasmax, threshold_low) - snyder_edd(
+        tasmin, tasmax, threshold_high
+    )
+
+    res.attrs["units"] = "degreedays_{}-{}{}".format(
+        threshold_low, threshold_high, tasmax.attrs["units"]
+    )

     return res


 def validate_edd_snyder_agriculture(ds, thresholds):
-    msg_null = 'hierid dims do not match 24378'
+    msg_null = "hierid dims do not match 24378"

     assert ds.hierid.shape == (24378,), msg_null

@@ -164,48 +167,48 @@ def tas_poly(ds, power, varname):

     powername = ordinal(power)

-    description = ('''
+    description = (
+        """
     Daily average temperature (degrees C){raised}

     Leap years are removed before counting days (uses a 365 day
     calendar).
-    '''.format(
-        raised='' if power == 1 else (
-            ' raised to the {powername} power'
-            .format(powername=powername)))).strip()
+    """.format(
+            raised=""
+            if power == 1
+            else (" raised to the {powername} power".format(powername=powername))
+        )
+    ).strip()

     ds1 = xr.Dataset()

     # remove leap years
     ds = remove_leap_days(ds)

     # do transformation
-    ds1[varname] = (ds.tas - 273.15)**power
+    ds1[varname] = (ds.tas - 273.15) ** power

     # Replace datetime64[ns] 'time' with YYYYDDD int 'day'
-    if ds.dims['time'] > 365:
+    if ds.dims["time"] > 365:
         raise ValueError

-    ds1.coords['day'] = ds['time.year']*1000 + np.arange(1, len(ds.time)+1)
-    ds1 = ds1.swap_dims({'time': 'day'})
-    ds1 = ds1.drop('time')
+    ds1.coords["day"] = ds["time.year"] * 1000 + np.arange(1, len(ds.time) + 1)
+    ds1 = ds1.swap_dims({"time": "day"})
+    ds1 = ds1.drop("time")

-    ds1 = ds1.rename({'day': 'time'})
+    ds1 = ds1.rename({"day": "time"})

     # document variable
-    ds1[varname].attrs['units'] = (
-        'C^{}'.format(power) if power > 1 else 'C')
+    ds1[varname].attrs["units"] = "C^{}".format(power) if power > 1 else "C"

-    ds1[varname].attrs['long_title'] = description.splitlines()[0]
-    ds1[varname].attrs['description'] = description
-    ds1[varname].attrs['variable'] = varname
+    ds1[varname].attrs["long_title"] = description.splitlines()[0]
+    ds1[varname].attrs["description"] = description
+    ds1[varname].attrs["variable"] = varname

     return ds1


 def ordinal(n):
-    """ Converts numbers into ordinal strings """
+    """Converts numbers into ordinal strings"""

-    return (
-        "%d%s" %
-        (n, "tsnrhtdd"[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4]))
+    return "%d%s" % (n, "tsnrhtdd"[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10 :: 4])