Skip to content

Commit

Permalink
Merge pull request #150 from pnuu/keep-stretch-dtype
Browse files Browse the repository at this point in the history
Keep the original dtype of the data when stretching
  • Loading branch information
pnuu authored Nov 20, 2023
2 parents 6617ff7 + afc693a commit 9564cba
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 25 deletions.
29 changes: 22 additions & 7 deletions trollimage/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,7 +1456,7 @@ def test_gamma(self):

def test_crude_stretch(self):
"""Check crude stretching."""
arr = np.arange(75).reshape(5, 5, 3)
arr = np.arange(75, dtype=np.float32).reshape(5, 5, 3)
data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B']})
img = xrimage.XRImage(data)
Expand All @@ -1467,19 +1467,34 @@ def test_crude_stretch(self):
enhs = img.data.attrs['enhancement_history'][0]
scale_expected = np.array([0.01388889, 0.01388889, 0.01388889])
offset_expected = np.array([0., -0.01388889, -0.02777778])
assert img.data.dtype == np.float32
np.testing.assert_allclose(enhs['scale'].values, scale_expected)
np.testing.assert_allclose(enhs['offset'].values, offset_expected)
np.testing.assert_allclose(red, arr[:, :, 0] / 72.)
np.testing.assert_allclose(green, (arr[:, :, 1] - 1.) / (73. - 1.))
np.testing.assert_allclose(blue, (arr[:, :, 2] - 2.) / (74. - 2.))

expected_red = arr[:, :, 0] / 72.
np.testing.assert_allclose(red, expected_red.astype(np.float32), rtol=1e-6)
expected_green = (arr[:, :, 1] - 1.) / (73. - 1.)
np.testing.assert_allclose(green, expected_green.astype(np.float32), rtol=1e-6)
expected_blue = (arr[:, :, 2] - 2.) / (74. - 2.)
np.testing.assert_allclose(blue, expected_blue.astype(np.float32), rtol=1e-6)

def test_crude_stretch_with_limits(self):
arr = np.arange(75).reshape(5, 5, 3).astype(float)
data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B']})
img = xrimage.XRImage(data)
img.crude_stretch(0, 74)
assert img.data.dtype == float
np.testing.assert_allclose(img.data.values, arr / 74.)

def test_crude_stretch_integer_data(self):
arr = np.arange(75, dtype=int).reshape(5, 5, 3)
data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B']})
img = xrimage.XRImage(data)
img.crude_stretch(0, 74)
assert img.data.dtype == np.float32
np.testing.assert_allclose(img.data.values, arr.astype(np.float32) / 74., rtol=1e-6)

def test_invert(self):
"""Check inversion of the image."""
arr = np.arange(75).reshape(5, 5, 3) / 75.
Expand Down Expand Up @@ -2325,10 +2340,10 @@ class TestXRImageSaveScaleOffset:
def setup_method(self) -> None:
"""Set up the test case."""
from trollimage import xrimage
data = xr.DataArray(np.arange(25).reshape(5, 5, 1), dims=[
data = xr.DataArray(np.arange(25, dtype=np.float32).reshape(5, 5, 1), dims=[
'y', 'x', 'bands'], coords={'bands': ['L']})
self.img = xrimage.XRImage(data)
rgb_data = xr.DataArray(np.arange(3 * 25).reshape(5, 5, 3), dims=[
rgb_data = xr.DataArray(np.arange(3 * 25, dtype=np.float32).reshape(5, 5, 3), dims=[
'y', 'x', 'bands'], coords={'bands': ['R', 'G', 'B']})
self.rgb_img = xrimage.XRImage(rgb_data)

Expand Down
52 changes: 34 additions & 18 deletions trollimage/xrimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,31 +1066,47 @@ def crude_stretch(self, min_stretch=None, max_stretch=None):
normalizes to the [0,1] range.
"""
if min_stretch is None:
non_band_dims = tuple(x for x in self.data.dims if x != 'bands')
min_stretch = self.data.min(dim=non_band_dims)
if max_stretch is None:
min_stretch = self._check_stretch_value(min_stretch, kind='min')
max_stretch = self._check_stretch_value(max_stretch, kind='max')
scale_factor = self._get_scale_factor(min_stretch, max_stretch)

attrs = self.data.attrs
offset = -min_stretch * scale_factor
self.data = np.multiply(self.data, scale_factor, dtype=scale_factor.dtype) + offset
self.data.attrs = attrs
self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor,
'offset': offset})

def _check_stretch_value(self, val, kind='min'):
if val is None:
non_band_dims = tuple(x for x in self.data.dims if x != 'bands')
max_stretch = self.data.max(dim=non_band_dims)
val = getattr(self.data, kind)(dim=non_band_dims)

if isinstance(min_stretch, (list, tuple)):
min_stretch = self.xrify_tuples(min_stretch)
if isinstance(max_stretch, (list, tuple)):
max_stretch = self.xrify_tuples(max_stretch)
if isinstance(val, (list, tuple)):
val = self.xrify_tuples(val)

try:
val = val.astype(self.data.dtype)
except AttributeError:
val = self.data.dtype.type(val)

return val

def _get_scale_factor(self, min_stretch, max_stretch):
delta = (max_stretch - min_stretch)
dtype = self._infer_scale_factor_dtype()
if isinstance(delta, xr.DataArray):
# fillna if delta is NaN
scale_factor = (1.0 / delta).fillna(0)
scale_factor = (1.0 / delta).fillna(0).astype(dtype)
else:
scale_factor = 1.0 / delta
attrs = self.data.attrs
offset = -min_stretch * scale_factor
self.data *= scale_factor
self.data += offset
self.data.attrs = attrs
self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor,
'offset': offset})
scale_factor = np.array(1.0 / delta, dtype=dtype)

return scale_factor

def _infer_scale_factor_dtype(self):
if np.issubdtype(self.data.dtype, np.integer):
return np.float32
return self.data.dtype

def stretch_hist_equalize(self, approximate=False):
"""Stretch the current image's colors through histogram equalization.
Expand Down

0 comments on commit 9564cba

Please sign in to comment.