From d34149cfcee2c8e323eaf1b88bc42a32f08402da Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Wed, 15 Nov 2023 15:37:30 +0200 Subject: [PATCH 1/7] Keep dtype when scaling the data --- trollimage/xrimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index d287adbf..54833703 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1086,8 +1086,8 @@ def crude_stretch(self, min_stretch=None, max_stretch=None): scale_factor = 1.0 / delta attrs = self.data.attrs offset = -min_stretch * scale_factor - self.data *= scale_factor - self.data += offset + self.data *= scale_factor.astype(self.data.dtype) + self.data += offset.astype(self.data.dtype) self.data.attrs = attrs self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor, 'offset': offset}) From f23bb0c976407ed31e713ed757b9df9a4a42799b Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Wed, 15 Nov 2023 16:21:40 +0200 Subject: [PATCH 2/7] Make sure stretching keeps the original dtype --- trollimage/tests/test_image.py | 12 ++++++++---- trollimage/xrimage.py | 8 ++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/trollimage/tests/test_image.py b/trollimage/tests/test_image.py index 08f051b9..d37dd951 100644 --- a/trollimage/tests/test_image.py +++ b/trollimage/tests/test_image.py @@ -1456,7 +1456,7 @@ def test_gamma(self): def test_crude_stretch(self): """Check crude stretching.""" - arr = np.arange(75).reshape(5, 5, 3) + arr = np.arange(75, dtype=np.float32).reshape(5, 5, 3) data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], coords={'bands': ['R', 'G', 'B']}) img = xrimage.XRImage(data) @@ -1467,11 +1467,15 @@ def test_crude_stretch(self): enhs = img.data.attrs['enhancement_history'][0] scale_expected = np.array([0.01388889, 0.01388889, 0.01388889]) offset_expected = np.array([0., -0.01388889, -0.02777778]) + assert img.data.dtype == np.float32 np.testing.assert_allclose(enhs['scale'].values, scale_expected) np.testing.assert_allclose(enhs['offset'].values, offset_expected) - np.testing.assert_allclose(red, arr[:, :, 0] / 72.) - np.testing.assert_allclose(green, (arr[:, :, 1] - 1.) / (73. - 1.)) - np.testing.assert_allclose(blue, (arr[:, :, 2] - 2.) / (74. - 2.)) + expected_red = arr[:, :, 0] / 72. + np.testing.assert_allclose(red, expected_red.astype(np.float32), rtol=1e-6) + expected_green = (arr[:, :, 1] - 1.) / (73. - 1.) + np.testing.assert_allclose(green, expected_green.astype(np.float32), rtol=1e-6) + expected_blue = (arr[:, :, 2] - 2.) / (74. - 2.) + np.testing.assert_allclose(blue, expected_blue.astype(np.float32), rtol=1e-6) arr = np.arange(75).reshape(5, 5, 3).astype(float) data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index 54833703..d6cc34cd 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1081,13 +1081,13 @@ def crude_stretch(self, min_stretch=None, max_stretch=None): delta = (max_stretch - min_stretch) if isinstance(delta, xr.DataArray): # fillna if delta is NaN - scale_factor = (1.0 / delta).fillna(0) + scale_factor = (1.0 / delta).fillna(0).astype(self.data.dtype) else: - scale_factor = 1.0 / delta + scale_factor = self.data.dtype.type(1.0 / delta) attrs = self.data.attrs offset = -min_stretch * scale_factor - self.data *= scale_factor.astype(self.data.dtype) - self.data += offset.astype(self.data.dtype) + self.data *= scale_factor + self.data += offset self.data.attrs = attrs self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor, 'offset': offset}) From b1d8d6f3a8df46c8331320e18a8598418eb947c6 Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 16 Nov 2023 12:49:24 +0200 Subject: [PATCH 3/7] Use float32 data for scale/offset attribute tests --- trollimage/tests/test_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trollimage/tests/test_image.py b/trollimage/tests/test_image.py index d37dd951..caad947d 100644 --- a/trollimage/tests/test_image.py +++ b/trollimage/tests/test_image.py @@ -2329,10 +2329,10 @@ class TestXRImageSaveScaleOffset: def setup_method(self) -> None: """Set up the test case.""" from trollimage import xrimage - data = xr.DataArray(np.arange(25).reshape(5, 5, 1), dims=[ + data = xr.DataArray(np.arange(25, dtype=np.float32).reshape(5, 5, 1), dims=[ 'y', 'x', 'bands'], coords={'bands': ['L']}) self.img = xrimage.XRImage(data) - rgb_data = xr.DataArray(np.arange(3 * 25).reshape(5, 5, 3), dims=[ + rgb_data = xr.DataArray(np.arange(3 * 25, dtype=np.float32).reshape(5, 5, 3), dims=[ 'y', 'x', 'bands'], coords={'bands': ['R', 'G', 'B']}) self.rgb_img = xrimage.XRImage(rgb_data) From 5238b57c6346371a587ab90dc9ec8a9e49c89656 Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 16 Nov 2023 13:56:12 +0200 Subject: [PATCH 4/7] Ensure min and max stretch have the same dtype as the data --- trollimage/xrimage.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index d6cc34cd..d79c1b7c 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1078,6 +1078,9 @@ def crude_stretch(self, min_stretch=None, max_stretch=None): if isinstance(max_stretch, (list, tuple)): max_stretch = self.xrify_tuples(max_stretch) + min_stretch = min_stretch.astype(self.data.dtype) + max_stretch = max_stretch.astype(self.data.dtype) + delta = (max_stretch - min_stretch) if isinstance(delta, xr.DataArray): # fillna if delta is NaN From 4b86901b34257c756e5f768c80640d86de237fcf Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 16 Nov 2023 15:28:53 +0200 Subject: [PATCH 5/7] Handle scalar stretch values --- trollimage/xrimage.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index d79c1b7c..9976710f 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1078,8 +1078,12 @@ def crude_stretch(self, min_stretch=None, max_stretch=None): if isinstance(max_stretch, (list, tuple)): max_stretch = self.xrify_tuples(max_stretch) - min_stretch = min_stretch.astype(self.data.dtype) - max_stretch = max_stretch.astype(self.data.dtype) + try: + min_stretch = min_stretch.astype(self.data.dtype) + max_stretch = max_stretch.astype(self.data.dtype) + except AttributeError: + min_stretch = self.data.dtype.type(min_stretch) + max_stretch = self.data.dtype.type(max_stretch) delta = (max_stretch - min_stretch) if isinstance(delta, xr.DataArray): From aee94f8ac26dc5e2f5a410c02a5cef06905a7e05 Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 16 Nov 2023 16:08:03 +0200 Subject: [PATCH 6/7] Refactor crude_stretch() --- trollimage/xrimage.py | 44 +++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index 9976710f..b66384b2 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1066,38 +1066,42 @@ def crude_stretch(self, min_stretch=None, max_stretch=None): normalizes to the [0,1] range. """ - if min_stretch is None: - non_band_dims = tuple(x for x in self.data.dims if x != 'bands') - min_stretch = self.data.min(dim=non_band_dims) - if max_stretch is None: + min_stretch = self._check_stretch_value(min_stretch, kind='min') + max_stretch = self._check_stretch_value(max_stretch, kind='max') + scale_factor = self._get_scale_factor(min_stretch, max_stretch) + + attrs = self.data.attrs + offset = -min_stretch * scale_factor + self.data *= scale_factor + self.data += offset + self.data.attrs = attrs + self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor, + 'offset': offset}) + + def _check_stretch_value(self, val, kind='min'): + if val is None: non_band_dims = tuple(x for x in self.data.dims if x != 'bands') - max_stretch = self.data.max(dim=non_band_dims) + val = getattr(self.data, kind)(dim=non_band_dims) - if isinstance(min_stretch, (list, tuple)): - min_stretch = self.xrify_tuples(min_stretch) - if isinstance(max_stretch, (list, tuple)): - max_stretch = self.xrify_tuples(max_stretch) + if isinstance(val, (list, tuple)): + val = self.xrify_tuples(val) try: - min_stretch = min_stretch.astype(self.data.dtype) - max_stretch = max_stretch.astype(self.data.dtype) + val = val.astype(self.data.dtype) except AttributeError: - min_stretch = self.data.dtype.type(min_stretch) - max_stretch = self.data.dtype.type(max_stretch) + val = self.data.dtype.type(val) + + return val + def _get_scale_factor(self, min_stretch, max_stretch): delta = (max_stretch - min_stretch) if isinstance(delta, xr.DataArray): # fillna if delta is NaN scale_factor = (1.0 / delta).fillna(0).astype(self.data.dtype) else: scale_factor = self.data.dtype.type(1.0 / delta) - attrs = self.data.attrs - offset = -min_stretch * scale_factor - self.data *= scale_factor - self.data += offset - self.data.attrs = attrs - self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor, - 'offset': offset}) + + return scale_factor def stretch_hist_equalize(self, approximate=False): """Stretch the current image's colors through histogram equalization. From afc693ae4dab2b60e37ae93f2858ba4ee563f594 Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Fri, 17 Nov 2023 08:56:22 +0200 Subject: [PATCH 7/7] Handle integer data in crude_stretch, refactor tests --- trollimage/tests/test_image.py | 11 +++++++++++ trollimage/xrimage.py | 13 +++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/trollimage/tests/test_image.py b/trollimage/tests/test_image.py index caad947d..88144a32 100644 --- a/trollimage/tests/test_image.py +++ b/trollimage/tests/test_image.py @@ -1477,13 +1477,24 @@ def test_crude_stretch(self): expected_blue = (arr[:, :, 2] - 2.) / (74. - 2.) np.testing.assert_allclose(blue, expected_blue.astype(np.float32), rtol=1e-6) + def test_crude_stretch_with_limits(self): arr = np.arange(75).reshape(5, 5, 3).astype(float) data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], coords={'bands': ['R', 'G', 'B']}) img = xrimage.XRImage(data) img.crude_stretch(0, 74) + assert img.data.dtype == float np.testing.assert_allclose(img.data.values, arr / 74.) + def test_crude_stretch_integer_data(self): + arr = np.arange(75, dtype=int).reshape(5, 5, 3) + data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], + coords={'bands': ['R', 'G', 'B']}) + img = xrimage.XRImage(data) + img.crude_stretch(0, 74) + assert img.data.dtype == np.float32 + np.testing.assert_allclose(img.data.values, arr.astype(np.float32) / 74., rtol=1e-6) + def test_invert(self): """Check inversion of the image.""" arr = np.arange(75).reshape(5, 5, 3) / 75. diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index b66384b2..4e7e6440 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1072,8 +1072,7 @@ def crude_stretch(self, min_stretch=None, max_stretch=None): attrs = self.data.attrs offset = -min_stretch * scale_factor - self.data *= scale_factor - self.data += offset + self.data = np.multiply(self.data, scale_factor, dtype=scale_factor.dtype) + offset self.data.attrs = attrs self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor, 'offset': offset}) @@ -1095,14 +1094,20 @@ def _check_stretch_value(self, val, kind='min'): def _get_scale_factor(self, min_stretch, max_stretch): delta = (max_stretch - min_stretch) + dtype = self._infer_scale_factor_dtype() if isinstance(delta, xr.DataArray): # fillna if delta is NaN - scale_factor = (1.0 / delta).fillna(0).astype(self.data.dtype) + scale_factor = (1.0 / delta).fillna(0).astype(dtype) else: - scale_factor = self.data.dtype.type(1.0 / delta) + scale_factor = np.array(1.0 / delta, dtype=dtype) return scale_factor + def _infer_scale_factor_dtype(self): + if np.issubdtype(self.data.dtype, np.integer): + return np.float32 + return self.data.dtype + def stretch_hist_equalize(self, approximate=False): """Stretch the current image's colors through histogram equalization.