diff --git a/trollimage/tests/test_image.py b/trollimage/tests/test_image.py index 08f051b9..88144a32 100644 --- a/trollimage/tests/test_image.py +++ b/trollimage/tests/test_image.py @@ -1456,7 +1456,7 @@ def test_gamma(self): def test_crude_stretch(self): """Check crude stretching.""" - arr = np.arange(75).reshape(5, 5, 3) + arr = np.arange(75, dtype=np.float32).reshape(5, 5, 3) data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], coords={'bands': ['R', 'G', 'B']}) img = xrimage.XRImage(data) @@ -1467,19 +1467,34 @@ def test_crude_stretch(self): enhs = img.data.attrs['enhancement_history'][0] scale_expected = np.array([0.01388889, 0.01388889, 0.01388889]) offset_expected = np.array([0., -0.01388889, -0.02777778]) + assert img.data.dtype == np.float32 np.testing.assert_allclose(enhs['scale'].values, scale_expected) np.testing.assert_allclose(enhs['offset'].values, offset_expected) - np.testing.assert_allclose(red, arr[:, :, 0] / 72.) - np.testing.assert_allclose(green, (arr[:, :, 1] - 1.) / (73. - 1.)) - np.testing.assert_allclose(blue, (arr[:, :, 2] - 2.) / (74. - 2.)) - + expected_red = arr[:, :, 0] / 72. + np.testing.assert_allclose(red, expected_red.astype(np.float32), rtol=1e-6) + expected_green = (arr[:, :, 1] - 1.) / (73. - 1.) + np.testing.assert_allclose(green, expected_green.astype(np.float32), rtol=1e-6) + expected_blue = (arr[:, :, 2] - 2.) / (74. - 2.) + np.testing.assert_allclose(blue, expected_blue.astype(np.float32), rtol=1e-6) + + def test_crude_stretch_with_limits(self): arr = np.arange(75).reshape(5, 5, 3).astype(float) data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], coords={'bands': ['R', 'G', 'B']}) img = xrimage.XRImage(data) img.crude_stretch(0, 74) + assert img.data.dtype == float np.testing.assert_allclose(img.data.values, arr / 74.) + def test_crude_stretch_integer_data(self): + arr = np.arange(75, dtype=int).reshape(5, 5, 3) + data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], + coords={'bands': ['R', 'G', 'B']}) + img = xrimage.XRImage(data) + img.crude_stretch(0, 74) + assert img.data.dtype == np.float32 + np.testing.assert_allclose(img.data.values, arr.astype(np.float32) / 74., rtol=1e-6) + def test_invert(self): """Check inversion of the image.""" arr = np.arange(75).reshape(5, 5, 3) / 75. @@ -2325,10 +2340,10 @@ class TestXRImageSaveScaleOffset: def setup_method(self) -> None: """Set up the test case.""" from trollimage import xrimage - data = xr.DataArray(np.arange(25).reshape(5, 5, 1), dims=[ + data = xr.DataArray(np.arange(25, dtype=np.float32).reshape(5, 5, 1), dims=[ 'y', 'x', 'bands'], coords={'bands': ['L']}) self.img = xrimage.XRImage(data) - rgb_data = xr.DataArray(np.arange(3 * 25).reshape(5, 5, 3), dims=[ + rgb_data = xr.DataArray(np.arange(3 * 25, dtype=np.float32).reshape(5, 5, 3), dims=[ 'y', 'x', 'bands'], coords={'bands': ['R', 'G', 'B']}) self.rgb_img = xrimage.XRImage(rgb_data) diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index d287adbf..4e7e6440 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1066,31 +1066,47 @@ def crude_stretch(self, min_stretch=None, max_stretch=None): normalizes to the [0,1] range. """ - if min_stretch is None: - non_band_dims = tuple(x for x in self.data.dims if x != 'bands') - min_stretch = self.data.min(dim=non_band_dims) - if max_stretch is None: + min_stretch = self._check_stretch_value(min_stretch, kind='min') + max_stretch = self._check_stretch_value(max_stretch, kind='max') + scale_factor = self._get_scale_factor(min_stretch, max_stretch) + + attrs = self.data.attrs + offset = -min_stretch * scale_factor + self.data = np.multiply(self.data, scale_factor, dtype=scale_factor.dtype) + offset + self.data.attrs = attrs + self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor, + 'offset': offset}) + + def _check_stretch_value(self, val, kind='min'): + if val is None: non_band_dims = tuple(x for x in self.data.dims if x != 'bands') - max_stretch = self.data.max(dim=non_band_dims) + val = getattr(self.data, kind)(dim=non_band_dims) - if isinstance(min_stretch, (list, tuple)): - min_stretch = self.xrify_tuples(min_stretch) - if isinstance(max_stretch, (list, tuple)): - max_stretch = self.xrify_tuples(max_stretch) + if isinstance(val, (list, tuple)): + val = self.xrify_tuples(val) + try: + val = val.astype(self.data.dtype) + except AttributeError: + val = self.data.dtype.type(val) + + return val + + def _get_scale_factor(self, min_stretch, max_stretch): delta = (max_stretch - min_stretch) + dtype = self._infer_scale_factor_dtype() if isinstance(delta, xr.DataArray): # fillna if delta is NaN - scale_factor = (1.0 / delta).fillna(0) + scale_factor = (1.0 / delta).fillna(0).astype(dtype) else: - scale_factor = 1.0 / delta - attrs = self.data.attrs - offset = -min_stretch * scale_factor - self.data *= scale_factor - self.data += offset - self.data.attrs = attrs - self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor, - 'offset': offset}) + scale_factor = np.array(1.0 / delta, dtype=dtype) + + return scale_factor + + def _infer_scale_factor_dtype(self): + if np.issubdtype(self.data.dtype, np.integer): + return np.float32 + return self.data.dtype def stretch_hist_equalize(self, approximate=False): """Stretch the current image's colors through histogram equalization.