diff --git a/trollimage/tests/test_image.py b/trollimage/tests/test_image.py index 3421efe8..1f8820cf 100644 --- a/trollimage/tests/test_image.py +++ b/trollimage/tests/test_image.py @@ -1631,6 +1631,25 @@ def test_linear_stretch_does_not_affect_alpha(self, dtype): np.testing.assert_allclose(img.data.values, res, atol=1.e-6) + @pytest.mark.parametrize(("dtype", "max_val", "exp_min", "exp_max"), + ((np.uint8, 255, -0.005358012691140175, 1.0053772069513798), + (np.int8, 127, -0.004926108196377754, 1.0058689523488282), + (np.uint16, 65535, -0.005050825305515899, 1.005050893505104), + (np.int16, 32767, -0.005052744992717635, 1.0050527782880818), + (np.uint32, 4294967295, -0.005050505077517274, 1.0050505395923495), + (np.int32, 2147483647, -0.00505050499355784, 1.0050505395923495), + (int, 2147483647, -0.00505050499355784, 1.0050505395923495), + )) + def test_linear_stretch_integers(self, dtype, max_val, exp_min, exp_max): + """Test linear stretch with low-bit unsigned integer data.""" + arr = np.linspace(0, max_val, num=75, dtype=dtype).reshape(5, 5, 3) + data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], + coords={'bands': ['R', 'G', 'B']}) + img = xrimage.XRImage(data) + img.stretch_linear() + assert img.data.values.min() == pytest.approx(exp_min) + assert img.data.values.max() == pytest.approx(exp_max) + @pytest.mark.parametrize("dtype", (np.float32, np.float64, float)) def test_histogram_stretch(self, dtype): """Test histogram stretching.""" diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index 515925e4..be218c03 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1082,6 +1082,7 @@ def crude_stretch(self, min_stretch=None, max_stretch=None): attrs = self.data.attrs offset = -min_stretch * scale_factor + self.data = np.multiply(self.data, scale_factor, dtype=scale_factor.dtype) + offset self.data.attrs = attrs self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor, @@ -1095,8 +1096,13 @@ def _check_stretch_value(self, val, kind='min'): if isinstance(val, (list, tuple)): val = self.xrify_tuples(val) + dtype = self.data.dtype + if dtype in (np.uint8, np.int8, np.uint16, np.int16): + dtype = np.dtype(np.float32) + elif np.issubdtype(dtype, np.integer) or isinstance(dtype, int): + dtype = np.dtype(np.float64) try: - val = val.astype(self.data.dtype) + val = val.astype(dtype) except AttributeError: val = self.data.dtype.type(val)