Skip to content

Commit

Permalink
Merge pull request #186 from pnuu/bugfix-invert-dtype-promotion
Browse files Browse the repository at this point in the history
Fix dtype promotion in channel inversion
  • Loading branch information
mraspaud authored Oct 14, 2024
2 parents c384f5b + 2ce875e commit 81b80e2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
11 changes: 8 additions & 3 deletions trollimage/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,8 +1595,8 @@ def test_crude_stretch_integer_data(self, dtype, max_stretch):
np.testing.assert_allclose(img.data.values, arr.astype(np.float32) / max_stretch, rtol=1e-6)

@pytest.mark.parametrize("dtype", (np.float32, np.float64, float))
def test_invert(self, dtype):
"""Check inversion of the image."""
def test_invert_single_parameter(self, dtype):
"""Check inversion of the image for single inversion parameter."""
arr = np.arange(75, dtype=dtype).reshape(5, 5, 3) / 75.
data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B']})
Expand All @@ -1608,7 +1608,11 @@ def test_invert(self, dtype):
assert img.data.dtype == dtype
assert np.allclose(img.data.values, 1 - arr)

data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
@pytest.mark.parametrize("dtype", (np.float32, np.float64, float))
def test_invert_parameter_for_each_channel(self, dtype):
"""Check inversion of the image for single inversion parameter."""
arr = np.arange(75, dtype=dtype).reshape(5, 5, 3) / 75.
data = xr.DataArray(arr, dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B']})
img = xrimage.XRImage(data)

Expand All @@ -1618,6 +1622,7 @@ def test_invert(self, dtype):
scale = xr.DataArray(np.array([-1, 1, -1]), dims=['bands'],
coords={'bands': ['R', 'G', 'B']})
np.testing.assert_allclose(img.data.values, (data * scale + offset).values)
assert img.data.dtype == dtype

@pytest.mark.parametrize("dtype", (np.float32, np.float64, float))
def test_linear_stretch(self, dtype):
Expand Down
3 changes: 1 addition & 2 deletions trollimage/xrimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,6 @@ def _check_stretch_value(self, val, kind='min'):
if val is None:
non_band_dims = tuple(x for x in self.data.dims if x != 'bands')
val = getattr(self.data, kind)(dim=non_band_dims)

if isinstance(val, (list, tuple)):
val = self.xrify_tuples(val)

Expand Down Expand Up @@ -1302,7 +1301,7 @@ def invert(self, invert=True):
logger.debug("Applying invert with parameters %s", str(invert))
if isinstance(invert, (tuple, list)):
invert = self.xrify_tuples(invert)
offset = invert.astype(int)
offset = invert.astype(self.data.dtype)
scale = (-1) ** offset
elif invert:
offset = 1
Expand Down

0 comments on commit 81b80e2

Please sign in to comment.