diff --git a/trollimage/tests/test_image.py b/trollimage/tests/test_image.py index f49c166a..1f8820cf 100644 --- a/trollimage/tests/test_image.py +++ b/trollimage/tests/test_image.py @@ -1631,18 +1631,24 @@ def test_linear_stretch_does_not_affect_alpha(self, dtype): np.testing.assert_allclose(img.data.values, res, atol=1.e-6) - @pytest.mark.parametrize("dtype", (np.uint8, np.uint16, int)) - def test_linear_stretch_integer(self, dtype): - """Test linear stretch with integer data.""" - arr = np.arange(75, dtype=dtype).reshape(5, 5, 3) - arr[4, 4, :] = 255 + @pytest.mark.parametrize(("dtype", "max_val", "exp_min", "exp_max"), + ((np.uint8, 255, -0.005358012691140175, 1.0053772069513798), + (np.int8, 127, -0.004926108196377754, 1.0058689523488282), + (np.uint16, 65535, -0.005050825305515899, 1.005050893505104), + (np.int16, 32767, -0.005052744992717635, 1.0050527782880818), + (np.uint32, 4294967295, -0.005050505077517274, 1.0050505395923495), + (np.int32, 2147483647, -0.00505050499355784, 1.0050505395923495), + (int, 2147483647, -0.00505050499355784, 1.0050505395923495), + )) + def test_linear_stretch_integers(self, dtype, max_val, exp_min, exp_max): + """Test linear stretch with low-bit unsigned integer data.""" + arr = np.linspace(0, max_val, num=75, dtype=dtype).reshape(5, 5, 3) data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], coords={'bands': ['R', 'G', 'B']}) img = xrimage.XRImage(data) img.stretch_linear() - - assert img.data.values.min() == pytest.approx(-0.0015614156835530892) - assert img.data.values.max() == pytest.approx(1.0960743801652901) + assert img.data.values.min() == pytest.approx(exp_min) + assert img.data.values.max() == pytest.approx(exp_max) @pytest.mark.parametrize("dtype", (np.float32, np.float64, float)) def test_histogram_stretch(self, dtype): diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index 70a87080..da0b28f8 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1097,8 +1097,10 @@ def _check_stretch_value(self, val, kind='min'): val = self.xrify_tuples(val) dtype = self.data.dtype - if np.issubdtype(dtype, np.integer): + if isinstance(dtype, (np.uint8, np.int8, np.uint16, np.int16)): dtype = np.dtype(np.float32) + elif np.issubdtype(dtype, np.integer) or isinstance(dtype, int): + dtype = np.dtype(np.float64) try: val = val.astype(dtype) except AttributeError: