From a1ba5586b62a431f591dca9d1ff724b6ee8ad1d5 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 14 Feb 2024 11:12:33 +0100 Subject: [PATCH 1/2] Allow linear alpha stretching in certain conditions --- trollimage/tests/test_image.py | 75 ++++++++++++++++++++++++++++++++++ trollimage/xrimage.py | 66 ++++++++++++++++++++---------- 2 files changed, 120 insertions(+), 21 deletions(-) diff --git a/trollimage/tests/test_image.py b/trollimage/tests/test_image.py index 493cc0e..294e932 100644 --- a/trollimage/tests/test_image.py +++ b/trollimage/tests/test_image.py @@ -1697,6 +1697,81 @@ def test_linear_stretch_does_not_affect_alpha(self, dtype): np.testing.assert_allclose(img.data.values, res, atol=1.e-6) + @pytest.mark.parametrize("dtype", (np.float32, np.float64, float)) + def test_linear_stretch_does_not_affect_alpha_with_partial_cutoffs(self, dtype): + """Test linear stretching with cutoffs.""" + arr = np.arange(100, dtype=dtype).reshape(5, 5, 4) / 74. + arr[:, :, -1] = 1 # alpha channel, fully opaque + data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], + coords={'bands': ['R', 'G', 'B', 'A']}) + img = xrimage.XRImage(data) + img.stretch_linear([(0.005, 0.005), (0.005, 0.005), (0.005, 0.005)]) + assert img.data.dtype == dtype + res = np.array([[[-0.005051, -0.005051, -0.005051, 1.], + [0.037037, 0.037037, 0.037037, 1.], + [0.079125, 0.079125, 0.079125, 1.], + [0.121212, 0.121212, 0.121212, 1.], + [0.1633, 0.1633, 0.1633, 1.]], + [[0.205387, 0.205387, 0.205387, 1.], + [0.247475, 0.247475, 0.247475, 1.], + [0.289562, 0.289562, 0.289562, 1.], + [0.33165, 0.33165, 0.33165, 1.], + [0.373737, 0.373737, 0.373737, 1.]], + [[0.415825, 0.415825, 0.415825, 1.], + [0.457912, 0.457912, 0.457912, 1.], + [0.5, 0.5, 0.5, 1.], + [0.542088, 0.542088, 0.542088, 1.], + [0.584175, 0.584175, 0.584175, 1.]], + [[0.626263, 0.626263, 0.626263, 1.], + [0.66835, 0.66835, 0.66835, 1.], + [0.710438, 0.710438, 0.710438, 1.], + [0.752525, 0.752525, 0.752525, 1.], + [0.794613, 0.794613, 0.794613, 1.]], + [[0.8367, 0.8367, 0.8367, 1.], + [0.878788, 0.878788, 0.878788, 1.], + [0.920875, 0.920875, 0.920875, 1.], + [0.962963, 0.962963, 0.962963, 1.], + [1.005051, 1.005051, 1.005051, 1.]]], dtype=dtype) + + np.testing.assert_allclose(img.data.values, res, atol=1.e-6) + + @pytest.mark.parametrize("dtype", (np.float32, np.float64, float)) + def test_linear_stretch_does_affect_alpha_with_explicit_cutoffs(self, dtype): + """Test linear stretching with full explicit cutoffs.""" + arr = np.arange(100, dtype=dtype).reshape(5, 5, 4) / 74. + data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'], + coords={'bands': ['R', 'G', 'B', 'A']}) + img = xrimage.XRImage(data) + img.stretch_linear([(0.005, 0.005), (0.005, 0.005), (0.005, 0.005), (0.005, 0.005)]) + assert img.data.dtype == dtype + res = np.array([[[-0.005051, -0.005051, -0.005051, -0.005051], + [0.037037, 0.037037, 0.037037, 0.037037], + [0.079125, 0.079125, 0.079125, 0.079125], + [0.121212, 0.121212, 0.121212, 0.121212], + [0.1633, 0.1633, 0.1633, 0.1633]], + [[0.205387, 0.205387, 0.205387, 0.205387], + [0.247475, 0.247475, 0.247475, 0.247475], + [0.289562, 0.289562, 0.289562, 0.289562], + [0.33165, 0.33165, 0.33165, 0.33165], + [0.373737, 0.373737, 0.373737, 0.373737]], + [[0.415825, 0.415825, 0.415825, 0.415825], + [0.457912, 0.457912, 0.457912, 0.457912], + [0.5, 0.5, 0.5, 0.5], + [0.542088, 0.542088, 0.542088, 0.542088], + [0.584175, 0.584175, 0.584175, 0.584175]], + [[0.626263, 0.626263, 0.626263, 0.626263], + [0.66835, 0.66835, 0.66835, 0.66835], + [0.710438, 0.710438, 0.710438, 0.710438], + [0.752525, 0.752525, 0.752525, 0.752525], + [0.794613, 0.794613, 0.794613, 0.794613]], + [[0.8367, 0.8367, 0.8367, 0.8367], + [0.878788, 0.878788, 0.878788, 0.878788], + [0.920875, 0.920875, 0.920875, 0.920875], + [0.962963, 0.962963, 0.962963, 0.962963], + [1.005051, 1.005051, 1.005051, 1.005051]]], dtype=dtype) + + np.testing.assert_allclose(img.data.values, res, atol=1.e-6) + @pytest.mark.parametrize(("dtype", "max_val", "exp_min", "exp_max"), ((np.uint8, 255, -0.005358012691140175, 1.0053772069513798), (np.int8, 127, -0.004926108196377754, 1.0058689523488282), diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index 2f6d53b..4fd1532 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1004,30 +1004,19 @@ def stretch(self, stretch="crude", **kwargs): else: raise TypeError("Stretch parameter must be a string or a tuple.") - @staticmethod - def _compute_quantile(data, dims, cutoffs): - """Compute quantile for stretch_linear. - - Dask delayed functions need to be non-internal functions (created - inside a function) to be serializable on a multi-process scheduler. - - Quantile requires the data to be loaded since it not supported on - dask arrays yet. - - """ - # numpy doesn't get a 'quantile' function until 1.15 - # for better backwards compatibility we use xarray's version - data_arr = xr.DataArray(data, dims=dims) - # delayed will provide us the fully computed xarray with ndarray - left, right = data_arr.quantile([cutoffs[0], 1. - cutoffs[1]], dim=['x', 'y']) - logger.debug("Interval: left=%s, right=%s", str(left), str(right)) - return left.data, right.data - def stretch_linear(self, cutoffs=(0.005, 0.005)): """Stretch linearly the contrast of the current image. Use *cutoffs* for left and right trimming. + If the cutoffs are just a tuple or list of two scalar, all the channels except the alpha channel + will be stretched with the cutoffs. + If the cutoffs are a sequence of tuples/lists of two scalar: + - if there is the same number of tuples/lists as channels, each channel will be stretched with the respective + cutoff. + - if there is one less tuples/lists as channels, the same applies, except for the alpha channel which will + not be stretched. + """ logger.debug("Perform a linear contrast stretch.") @@ -1047,12 +1036,20 @@ def _get_left_and_right_quantiles_for_linear_stretch(self, cutoffs): cutoff_type = self.data.dtype data = self.data - if 'A' in self.data.coords['bands'].values: + nb_bands = len(data.coords["bands"]) + + dont_stretch_alpha = ('A' in self.data.coords['bands'].values and + (np.isscalar(cutoffs[0]) or len(cutoffs) == nb_bands - 1)) + + if np.isscalar(cutoffs[0]): + cutoffs = [cutoffs] * nb_bands + + if dont_stretch_alpha: data = self.data.sel(bands=self.data.coords['bands'].values[:-1]) left_data, right_data = self._get_left_and_right_quantiles_without_alpha(data, cutoffs, cutoff_type) - if 'A' in self.data.coords['bands'].values: + if dont_stretch_alpha: left_data = np.hstack([left_data, np.array([0])]) right_data = np.hstack([right_data, np.array([1])]) left = xr.DataArray(left_data, dims=('bands',), @@ -1071,6 +1068,33 @@ def _get_left_and_right_quantiles_without_alpha(self, data, cutoffs, cutoff_type dtype=cutoff_type) return left_data, right_data + @staticmethod + def _compute_quantile(data, dims, cutoffs): + """Compute quantile for stretch_linear. + + Dask delayed functions need to be non-internal functions (created + inside a function) to be serializable on a multi-process scheduler. + + Quantile requires the data to be loaded since it not supported on + dask arrays yet. + + """ + # numpy doesn't get a 'quantile' function until 1.15 + # for better backwards compatibility we use xarray's version + data_arr = xr.DataArray(data, dims=dims) + # delayed will provide us the fully computed xarray with ndarray + nb_bands = len(data_arr.coords["bands"]) + + left = [] + right = [] + for i in range(nb_bands): + left_i, right_i = data_arr.isel(bands=i).quantile([cutoffs[i][0], 1-cutoffs[i][1]]) + left.append(left_i) + right.append(right_i) + + logger.debug("Interval: left=%s, right=%s", str(left), str(right)) + return np.array(left), np.array(right) + def crude_stretch(self, min_stretch=None, max_stretch=None): """Perform simple linear stretching. From 97231c04ad44793a063f8399cfffa55cf6d100a7 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Wed, 14 Feb 2024 09:54:42 -0600 Subject: [PATCH 2/2] Fix docstring formatting for stretch_linear --- trollimage/xrimage.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index 4fd1532..5dba8a5 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -1009,13 +1009,14 @@ def stretch_linear(self, cutoffs=(0.005, 0.005)): Use *cutoffs* for left and right trimming. - If the cutoffs are just a tuple or list of two scalar, all the channels except the alpha channel - will be stretched with the cutoffs. - If the cutoffs are a sequence of tuples/lists of two scalar: - - if there is the same number of tuples/lists as channels, each channel will be stretched with the respective - cutoff. - - if there is one less tuples/lists as channels, the same applies, except for the alpha channel which will - not be stretched. + If the cutoffs are just a tuple or list of two scalars, all the + channels except the alpha channel will be stretched with the cutoffs. + If the cutoffs are a sequence of tuples/lists of two scalars then: + + - if there are the same number of tuples/lists as channels, each channel will be stretched with the respective + cutoff. + - if there is one less tuple/list as channels, the same applies, except for the alpha channel which will + not be stretched. """ logger.debug("Perform a linear contrast stretch.")