Skip to content

Commit

Permalink
Merge pull request #153 from pnuu/bugfix-uint8-stretch
Browse files Browse the repository at this point in the history
Fix stretching integer data
  • Loading branch information
mraspaud authored Nov 24, 2023
2 parents dd16cdf + 034d659 commit 5b9227b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
19 changes: 19 additions & 0 deletions trollimage/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,6 +1631,25 @@ def test_linear_stretch_does_not_affect_alpha(self, dtype):

np.testing.assert_allclose(img.data.values, res, atol=1.e-6)

@pytest.mark.parametrize(("dtype", "max_val", "exp_min", "exp_max"),
((np.uint8, 255, -0.005358012691140175, 1.0053772069513798),
(np.int8, 127, -0.004926108196377754, 1.0058689523488282),
(np.uint16, 65535, -0.005050825305515899, 1.005050893505104),
(np.int16, 32767, -0.005052744992717635, 1.0050527782880818),
(np.uint32, 4294967295, -0.005050505077517274, 1.0050505395923495),
(np.int32, 2147483647, -0.00505050499355784, 1.0050505395923495),
(int, 2147483647, -0.00505050499355784, 1.0050505395923495),
))
def test_linear_stretch_integers(self, dtype, max_val, exp_min, exp_max):
"""Test linear stretch with low-bit unsigned integer data."""
arr = np.linspace(0, max_val, num=75, dtype=dtype).reshape(5, 5, 3)
data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B']})
img = xrimage.XRImage(data)
img.stretch_linear()
assert img.data.values.min() == pytest.approx(exp_min)
assert img.data.values.max() == pytest.approx(exp_max)

@pytest.mark.parametrize("dtype", (np.float32, np.float64, float))
def test_histogram_stretch(self, dtype):
"""Test histogram stretching."""
Expand Down
8 changes: 7 additions & 1 deletion trollimage/xrimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,7 @@ def crude_stretch(self, min_stretch=None, max_stretch=None):

attrs = self.data.attrs
offset = -min_stretch * scale_factor

self.data = np.multiply(self.data, scale_factor, dtype=scale_factor.dtype) + offset
self.data.attrs = attrs
self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor,
Expand All @@ -1095,8 +1096,13 @@ def _check_stretch_value(self, val, kind='min'):
if isinstance(val, (list, tuple)):
val = self.xrify_tuples(val)

dtype = self.data.dtype
if dtype in (np.uint8, np.int8, np.uint16, np.int16):
dtype = np.dtype(np.float32)
elif np.issubdtype(dtype, np.integer) or isinstance(dtype, int):
dtype = np.dtype(np.float64)
try:
val = val.astype(self.data.dtype)
val = val.astype(dtype)
except AttributeError:
val = self.data.dtype.type(val)

Expand Down

0 comments on commit 5b9227b

Please sign in to comment.