Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow linear alpha stretching in certain conditions #163

Merged
merged 2 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions trollimage/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,81 @@ def test_linear_stretch_does_not_affect_alpha(self, dtype):

np.testing.assert_allclose(img.data.values, res, atol=1.e-6)

@pytest.mark.parametrize("dtype", (np.float32, np.float64, float))
def test_linear_stretch_does_not_affect_alpha_with_partial_cutoffs(self, dtype):
"""Test linear stretching with cutoffs."""
arr = np.arange(100, dtype=dtype).reshape(5, 5, 4) / 74.
arr[:, :, -1] = 1 # alpha channel, fully opaque
data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B', 'A']})
img = xrimage.XRImage(data)
img.stretch_linear([(0.005, 0.005), (0.005, 0.005), (0.005, 0.005)])
assert img.data.dtype == dtype
res = np.array([[[-0.005051, -0.005051, -0.005051, 1.],
[0.037037, 0.037037, 0.037037, 1.],
[0.079125, 0.079125, 0.079125, 1.],
[0.121212, 0.121212, 0.121212, 1.],
[0.1633, 0.1633, 0.1633, 1.]],
[[0.205387, 0.205387, 0.205387, 1.],
[0.247475, 0.247475, 0.247475, 1.],
[0.289562, 0.289562, 0.289562, 1.],
[0.33165, 0.33165, 0.33165, 1.],
[0.373737, 0.373737, 0.373737, 1.]],
[[0.415825, 0.415825, 0.415825, 1.],
[0.457912, 0.457912, 0.457912, 1.],
[0.5, 0.5, 0.5, 1.],
[0.542088, 0.542088, 0.542088, 1.],
[0.584175, 0.584175, 0.584175, 1.]],
[[0.626263, 0.626263, 0.626263, 1.],
[0.66835, 0.66835, 0.66835, 1.],
[0.710438, 0.710438, 0.710438, 1.],
[0.752525, 0.752525, 0.752525, 1.],
[0.794613, 0.794613, 0.794613, 1.]],
[[0.8367, 0.8367, 0.8367, 1.],
[0.878788, 0.878788, 0.878788, 1.],
[0.920875, 0.920875, 0.920875, 1.],
[0.962963, 0.962963, 0.962963, 1.],
[1.005051, 1.005051, 1.005051, 1.]]], dtype=dtype)

np.testing.assert_allclose(img.data.values, res, atol=1.e-6)

@pytest.mark.parametrize("dtype", (np.float32, np.float64, float))
def test_linear_stretch_does_affect_alpha_with_explicit_cutoffs(self, dtype):
"""Test linear stretching with full explicit cutoffs."""
arr = np.arange(100, dtype=dtype).reshape(5, 5, 4) / 74.
data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B', 'A']})
img = xrimage.XRImage(data)
img.stretch_linear([(0.005, 0.005), (0.005, 0.005), (0.005, 0.005), (0.005, 0.005)])
assert img.data.dtype == dtype
res = np.array([[[-0.005051, -0.005051, -0.005051, -0.005051],
[0.037037, 0.037037, 0.037037, 0.037037],
[0.079125, 0.079125, 0.079125, 0.079125],
[0.121212, 0.121212, 0.121212, 0.121212],
[0.1633, 0.1633, 0.1633, 0.1633]],
[[0.205387, 0.205387, 0.205387, 0.205387],
[0.247475, 0.247475, 0.247475, 0.247475],
[0.289562, 0.289562, 0.289562, 0.289562],
[0.33165, 0.33165, 0.33165, 0.33165],
[0.373737, 0.373737, 0.373737, 0.373737]],
[[0.415825, 0.415825, 0.415825, 0.415825],
[0.457912, 0.457912, 0.457912, 0.457912],
[0.5, 0.5, 0.5, 0.5],
[0.542088, 0.542088, 0.542088, 0.542088],
[0.584175, 0.584175, 0.584175, 0.584175]],
[[0.626263, 0.626263, 0.626263, 0.626263],
[0.66835, 0.66835, 0.66835, 0.66835],
[0.710438, 0.710438, 0.710438, 0.710438],
[0.752525, 0.752525, 0.752525, 0.752525],
[0.794613, 0.794613, 0.794613, 0.794613]],
[[0.8367, 0.8367, 0.8367, 0.8367],
[0.878788, 0.878788, 0.878788, 0.878788],
[0.920875, 0.920875, 0.920875, 0.920875],
[0.962963, 0.962963, 0.962963, 0.962963],
[1.005051, 1.005051, 1.005051, 1.005051]]], dtype=dtype)

np.testing.assert_allclose(img.data.values, res, atol=1.e-6)

@pytest.mark.parametrize(("dtype", "max_val", "exp_min", "exp_max"),
((np.uint8, 255, -0.005358012691140175, 1.0053772069513798),
(np.int8, 127, -0.004926108196377754, 1.0058689523488282),
Expand Down
67 changes: 46 additions & 21 deletions trollimage/xrimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,30 +1004,20 @@ def stretch(self, stretch="crude", **kwargs):
else:
raise TypeError("Stretch parameter must be a string or a tuple.")

@staticmethod
def _compute_quantile(data, dims, cutoffs):
"""Compute quantile for stretch_linear.

Dask delayed functions need to be non-internal functions (created
inside a function) to be serializable on a multi-process scheduler.

Quantile requires the data to be loaded since it not supported on
dask arrays yet.

"""
# numpy doesn't get a 'quantile' function until 1.15
# for better backwards compatibility we use xarray's version
data_arr = xr.DataArray(data, dims=dims)
# delayed will provide us the fully computed xarray with ndarray
left, right = data_arr.quantile([cutoffs[0], 1. - cutoffs[1]], dim=['x', 'y'])
logger.debug("Interval: left=%s, right=%s", str(left), str(right))
return left.data, right.data

def stretch_linear(self, cutoffs=(0.005, 0.005)):
"""Stretch linearly the contrast of the current image.

Use *cutoffs* for left and right trimming.

If the cutoffs are just a tuple or list of two scalars, all the
channels except the alpha channel will be stretched with the cutoffs.
If the cutoffs are a sequence of tuples/lists of two scalars then:

- if there are the same number of tuples/lists as channels, each channel will be stretched with the respective
cutoff.
- if there is one less tuple/list as channels, the same applies, except for the alpha channel which will
not be stretched.

"""
logger.debug("Perform a linear contrast stretch.")

Expand All @@ -1047,12 +1037,20 @@ def _get_left_and_right_quantiles_for_linear_stretch(self, cutoffs):
cutoff_type = self.data.dtype

data = self.data
if 'A' in self.data.coords['bands'].values:
nb_bands = len(data.coords["bands"])

dont_stretch_alpha = ('A' in self.data.coords['bands'].values and
(np.isscalar(cutoffs[0]) or len(cutoffs) == nb_bands - 1))

if np.isscalar(cutoffs[0]):
cutoffs = [cutoffs] * nb_bands

if dont_stretch_alpha:
data = self.data.sel(bands=self.data.coords['bands'].values[:-1])

left_data, right_data = self._get_left_and_right_quantiles_without_alpha(data, cutoffs, cutoff_type)

if 'A' in self.data.coords['bands'].values:
if dont_stretch_alpha:
left_data = np.hstack([left_data, np.array([0])])
right_data = np.hstack([right_data, np.array([1])])
left = xr.DataArray(left_data, dims=('bands',),
Expand All @@ -1071,6 +1069,33 @@ def _get_left_and_right_quantiles_without_alpha(self, data, cutoffs, cutoff_type
dtype=cutoff_type)
return left_data, right_data

@staticmethod
def _compute_quantile(data, dims, cutoffs):
"""Compute quantile for stretch_linear.

Dask delayed functions need to be non-internal functions (created
inside a function) to be serializable on a multi-process scheduler.

Quantile requires the data to be loaded since it not supported on
dask arrays yet.

"""
# numpy doesn't get a 'quantile' function until 1.15
# for better backwards compatibility we use xarray's version
data_arr = xr.DataArray(data, dims=dims)
# delayed will provide us the fully computed xarray with ndarray
nb_bands = len(data_arr.coords["bands"])

left = []
right = []
for i in range(nb_bands):
left_i, right_i = data_arr.isel(bands=i).quantile([cutoffs[i][0], 1-cutoffs[i][1]])
left.append(left_i)
right.append(right_i)

logger.debug("Interval: left=%s, right=%s", str(left), str(right))
return np.array(left), np.array(right)

def crude_stretch(self, min_stretch=None, max_stretch=None):
"""Perform simple linear stretching.

Expand Down
Loading