Skip to content

Commit

Permalink
Allow linear alpha stretching in certain conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
mraspaud committed Feb 14, 2024
1 parent 36d98aa commit a1ba558
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 21 deletions.
75 changes: 75 additions & 0 deletions trollimage/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,81 @@ def test_linear_stretch_does_not_affect_alpha(self, dtype):

np.testing.assert_allclose(img.data.values, res, atol=1.e-6)

@pytest.mark.parametrize("dtype", (np.float32, np.float64, float))
def test_linear_stretch_does_not_affect_alpha_with_partial_cutoffs(self, dtype):
"""Test linear stretching with cutoffs."""
arr = np.arange(100, dtype=dtype).reshape(5, 5, 4) / 74.
arr[:, :, -1] = 1 # alpha channel, fully opaque
data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B', 'A']})
img = xrimage.XRImage(data)
img.stretch_linear([(0.005, 0.005), (0.005, 0.005), (0.005, 0.005)])
assert img.data.dtype == dtype
res = np.array([[[-0.005051, -0.005051, -0.005051, 1.],
[0.037037, 0.037037, 0.037037, 1.],
[0.079125, 0.079125, 0.079125, 1.],
[0.121212, 0.121212, 0.121212, 1.],
[0.1633, 0.1633, 0.1633, 1.]],
[[0.205387, 0.205387, 0.205387, 1.],
[0.247475, 0.247475, 0.247475, 1.],
[0.289562, 0.289562, 0.289562, 1.],
[0.33165, 0.33165, 0.33165, 1.],
[0.373737, 0.373737, 0.373737, 1.]],
[[0.415825, 0.415825, 0.415825, 1.],
[0.457912, 0.457912, 0.457912, 1.],
[0.5, 0.5, 0.5, 1.],
[0.542088, 0.542088, 0.542088, 1.],
[0.584175, 0.584175, 0.584175, 1.]],
[[0.626263, 0.626263, 0.626263, 1.],
[0.66835, 0.66835, 0.66835, 1.],
[0.710438, 0.710438, 0.710438, 1.],
[0.752525, 0.752525, 0.752525, 1.],
[0.794613, 0.794613, 0.794613, 1.]],
[[0.8367, 0.8367, 0.8367, 1.],
[0.878788, 0.878788, 0.878788, 1.],
[0.920875, 0.920875, 0.920875, 1.],
[0.962963, 0.962963, 0.962963, 1.],
[1.005051, 1.005051, 1.005051, 1.]]], dtype=dtype)

np.testing.assert_allclose(img.data.values, res, atol=1.e-6)

@pytest.mark.parametrize("dtype", (np.float32, np.float64, float))
def test_linear_stretch_does_affect_alpha_with_explicit_cutoffs(self, dtype):
"""Test linear stretching with full explicit cutoffs."""
arr = np.arange(100, dtype=dtype).reshape(5, 5, 4) / 74.
data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B', 'A']})
img = xrimage.XRImage(data)
img.stretch_linear([(0.005, 0.005), (0.005, 0.005), (0.005, 0.005), (0.005, 0.005)])
assert img.data.dtype == dtype
res = np.array([[[-0.005051, -0.005051, -0.005051, -0.005051],
[0.037037, 0.037037, 0.037037, 0.037037],
[0.079125, 0.079125, 0.079125, 0.079125],
[0.121212, 0.121212, 0.121212, 0.121212],
[0.1633, 0.1633, 0.1633, 0.1633]],
[[0.205387, 0.205387, 0.205387, 0.205387],
[0.247475, 0.247475, 0.247475, 0.247475],
[0.289562, 0.289562, 0.289562, 0.289562],
[0.33165, 0.33165, 0.33165, 0.33165],
[0.373737, 0.373737, 0.373737, 0.373737]],
[[0.415825, 0.415825, 0.415825, 0.415825],
[0.457912, 0.457912, 0.457912, 0.457912],
[0.5, 0.5, 0.5, 0.5],
[0.542088, 0.542088, 0.542088, 0.542088],
[0.584175, 0.584175, 0.584175, 0.584175]],
[[0.626263, 0.626263, 0.626263, 0.626263],
[0.66835, 0.66835, 0.66835, 0.66835],
[0.710438, 0.710438, 0.710438, 0.710438],
[0.752525, 0.752525, 0.752525, 0.752525],
[0.794613, 0.794613, 0.794613, 0.794613]],
[[0.8367, 0.8367, 0.8367, 0.8367],
[0.878788, 0.878788, 0.878788, 0.878788],
[0.920875, 0.920875, 0.920875, 0.920875],
[0.962963, 0.962963, 0.962963, 0.962963],
[1.005051, 1.005051, 1.005051, 1.005051]]], dtype=dtype)

np.testing.assert_allclose(img.data.values, res, atol=1.e-6)

@pytest.mark.parametrize(("dtype", "max_val", "exp_min", "exp_max"),
((np.uint8, 255, -0.005358012691140175, 1.0053772069513798),
(np.int8, 127, -0.004926108196377754, 1.0058689523488282),
Expand Down
66 changes: 45 additions & 21 deletions trollimage/xrimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,30 +1004,19 @@ def stretch(self, stretch="crude", **kwargs):
else:
raise TypeError("Stretch parameter must be a string or a tuple.")

@staticmethod
def _compute_quantile(data, dims, cutoffs):
"""Compute quantile for stretch_linear.
Dask delayed functions need to be non-internal functions (created
inside a function) to be serializable on a multi-process scheduler.
Quantile requires the data to be loaded since it not supported on
dask arrays yet.
"""
# numpy doesn't get a 'quantile' function until 1.15
# for better backwards compatibility we use xarray's version
data_arr = xr.DataArray(data, dims=dims)
# delayed will provide us the fully computed xarray with ndarray
left, right = data_arr.quantile([cutoffs[0], 1. - cutoffs[1]], dim=['x', 'y'])
logger.debug("Interval: left=%s, right=%s", str(left), str(right))
return left.data, right.data

def stretch_linear(self, cutoffs=(0.005, 0.005)):
"""Stretch linearly the contrast of the current image.
Use *cutoffs* for left and right trimming.
If the cutoffs are just a tuple or list of two scalar, all the channels except the alpha channel
will be stretched with the cutoffs.
If the cutoffs are a sequence of tuples/lists of two scalar:
- if there is the same number of tuples/lists as channels, each channel will be stretched with the respective
cutoff.
- if there is one less tuples/lists as channels, the same applies, except for the alpha channel which will
not be stretched.
"""
logger.debug("Perform a linear contrast stretch.")

Expand All @@ -1047,12 +1036,20 @@ def _get_left_and_right_quantiles_for_linear_stretch(self, cutoffs):
cutoff_type = self.data.dtype

data = self.data
if 'A' in self.data.coords['bands'].values:
nb_bands = len(data.coords["bands"])

dont_stretch_alpha = ('A' in self.data.coords['bands'].values and
(np.isscalar(cutoffs[0]) or len(cutoffs) == nb_bands - 1))

if np.isscalar(cutoffs[0]):
cutoffs = [cutoffs] * nb_bands

if dont_stretch_alpha:
data = self.data.sel(bands=self.data.coords['bands'].values[:-1])

left_data, right_data = self._get_left_and_right_quantiles_without_alpha(data, cutoffs, cutoff_type)

if 'A' in self.data.coords['bands'].values:
if dont_stretch_alpha:
left_data = np.hstack([left_data, np.array([0])])
right_data = np.hstack([right_data, np.array([1])])
left = xr.DataArray(left_data, dims=('bands',),
Expand All @@ -1071,6 +1068,33 @@ def _get_left_and_right_quantiles_without_alpha(self, data, cutoffs, cutoff_type
dtype=cutoff_type)
return left_data, right_data

@staticmethod
def _compute_quantile(data, dims, cutoffs):
"""Compute quantile for stretch_linear.
Dask delayed functions need to be non-internal functions (created
inside a function) to be serializable on a multi-process scheduler.
Quantile requires the data to be loaded since it not supported on
dask arrays yet.
"""
# numpy doesn't get a 'quantile' function until 1.15
# for better backwards compatibility we use xarray's version
data_arr = xr.DataArray(data, dims=dims)
# delayed will provide us the fully computed xarray with ndarray
nb_bands = len(data_arr.coords["bands"])

left = []
right = []
for i in range(nb_bands):
left_i, right_i = data_arr.isel(bands=i).quantile([cutoffs[i][0], 1-cutoffs[i][1]])
left.append(left_i)
right.append(right_i)

logger.debug("Interval: left=%s, right=%s", str(left), str(right))
return np.array(left), np.array(right)

def crude_stretch(self, min_stretch=None, max_stretch=None):
"""Perform simple linear stretching.
Expand Down

0 comments on commit a1ba558

Please sign in to comment.