Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Respect the bounding_box in inverse transforms #498

Draft
wants to merge 14 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

- Force ``bounding_box`` to always be returned as a ``F`` ordered box. [#522]

- Fixed a bug where evaluating the inverse transform did not
respect the bounding box. [#498]

0.21.0 (2024-03-10)
-------------------

Expand Down
11 changes: 9 additions & 2 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,13 @@
be returned in the ``(x, y)`` order, where for an image, ``x`` is the
horizontal coordinate and ``y`` is the vertical coordinate.
"""
world_arrays = self._add_units_input(world_arrays, self.backward_transform, self.output_frame)
try:
backward_transform = self.backward_transform
world_arrays = self._add_units_input(world_arrays,
backward_transform,
self.output_frame)
except NotImplementedError:
pass

Check warning on line 141 in gwcs/api.py

View check run for this annotation

Codecov / codecov/patch

gwcs/api.py#L140-L141

Added lines #L140 - L141 were not covered by tests

result = self.invert(*world_arrays, with_units=False)

Expand Down Expand Up @@ -312,8 +318,9 @@
"""
Convert world coordinates to pixel values.
"""
#args = high_level_objects_to_values(*world_objects, low_level_wcs=self)
#result = self.invert(*args)
result = self.invert(*world_objects, with_units=True)

if self.input_frame.naxes > 1:
first_res = result[0]
if not utils.isnumerical(first_res):
Expand Down
31 changes: 30 additions & 1 deletion gwcs/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@
transform = models.Mapping((2, 0, 1)) | celestial & wave_model | models.Mapping((1, 2, 0))

sky_frame = cf.CelestialFrame(axes_order=(2, 0),
reference_frame=coord.Galactic(), axes_names=("Longitude", "Latitude"))
reference_frame=coord.Galactic(),
axes_names=("Longitude", "Latitude"))
wave_frame = cf.SpectralFrame(axes_order=(1, ), unit=u.Hz, axes_names=("Frequency",))

frame = cf.CompositeFrame([sky_frame, wave_frame])
Expand Down Expand Up @@ -477,3 +478,31 @@
w.pixel_shape = (16, 32, 21, 11, 11, 2)

return w


def gwcs_simple_imaging_no_units():
shift_by_crpix = models.Shift(-2048) & models.Shift(-1024)
matrix = np.array([[1.290551569736E-05, 5.9525007864732E-06],

Check warning on line 485 in gwcs/examples.py

View check run for this annotation

Codecov / codecov/patch

gwcs/examples.py#L484-L485

Added lines #L484 - L485 were not covered by tests
[5.0226382102765E-06 , -1.2644844123757E-05]])
rotation = models.AffineTransformation2D(matrix,

Check warning on line 487 in gwcs/examples.py

View check run for this annotation

Codecov / codecov/patch

gwcs/examples.py#L487

Added line #L487 was not covered by tests
translation=[0, 0])

rotation.inverse = models.AffineTransformation2D(np.linalg.inv(matrix),

Check warning on line 490 in gwcs/examples.py

View check run for this annotation

Codecov / codecov/patch

gwcs/examples.py#L490

Added line #L490 was not covered by tests
translation=[0, 0])
tan = models.Pix2Sky_TAN()
celestial_rotation = models.RotateNative2Celestial(5.63056810618,

Check warning on line 493 in gwcs/examples.py

View check run for this annotation

Codecov / codecov/patch

gwcs/examples.py#L492-L493

Added lines #L492 - L493 were not covered by tests
-72.05457184279,
180)
det2sky = shift_by_crpix | rotation | tan | celestial_rotation
det2sky.name = "linear_transform"

Check warning on line 497 in gwcs/examples.py

View check run for this annotation

Codecov / codecov/patch

gwcs/examples.py#L496-L497

Added lines #L496 - L497 were not covered by tests

detector_frame = cf.Frame2D(name="detector", axes_names=("x", "y"),

Check warning on line 499 in gwcs/examples.py

View check run for this annotation

Codecov / codecov/patch

gwcs/examples.py#L499

Added line #L499 was not covered by tests
unit=(u.pix, u.pix))
sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs',

Check warning on line 501 in gwcs/examples.py

View check run for this annotation

Codecov / codecov/patch

gwcs/examples.py#L501

Added line #L501 was not covered by tests
unit=(u.deg, u.deg))
pipeline = [(detector_frame, det2sky),

Check warning on line 503 in gwcs/examples.py

View check run for this annotation

Codecov / codecov/patch

gwcs/examples.py#L503

Added line #L503 was not covered by tests
(sky_frame, None)
]
w = wcs.WCS(pipeline)
w.bounding_box = ((2, 100), (5, 500))
return w

Check warning on line 508 in gwcs/examples.py

View check run for this annotation

Codecov / codecov/patch

gwcs/examples.py#L506-L508

Added lines #L506 - L508 were not covered by tests
9 changes: 7 additions & 2 deletions gwcs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
"""
import pytest

from .. import examples
from .. import geometry
from gwcs import examples
from gwcs import geometry


@pytest.fixture
Expand Down Expand Up @@ -141,3 +141,8 @@
@pytest.fixture
def cart_to_spher():
return geometry.CartesianToSpherical()


@pytest.fixture
def gwcs_simple_imaging_no_units():
return examples.gwcs_simple_imaging_no_units()

Check warning on line 148 in gwcs/tests/conftest.py

View check run for this annotation

Codecov / codecov/patch

gwcs/tests/conftest.py#L148

Added line #L148 was not covered by tests
17 changes: 4 additions & 13 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def wcs_ndim_types_units(request):

@fixture_all_wcses
def test_lowlevel_types(wcsobj):
pytest.importorskip("typeguard")
try:
# Skip this on older versions of astropy where it dosen't exist.
from astropy.wcs.wcsapi.tests.utils import validate_low_level_wcs_types
Expand Down Expand Up @@ -236,12 +235,12 @@ def test_world_axis_object_classes_4d(gwcs_4d_identity_units):
def _compare_frame_output(wc1, wc2):
if isinstance(wc1, coord.SkyCoord):
assert isinstance(wc1.frame, type(wc2.frame))
assert u.allclose(wc1.spherical.lon, wc2.spherical.lon)
assert u.allclose(wc1.spherical.lat, wc2.spherical.lat)
assert u.allclose(wc1.spherical.distance, wc2.spherical.distance)
assert u.allclose(wc1.spherical.lon, wc2.spherical.lon, equal_nan=True)
assert u.allclose(wc1.spherical.lat, wc2.spherical.lat, equal_nan=True)
assert u.allclose(wc1.spherical.distance, wc2.spherical.distance, equal_nan=True)

elif isinstance(wc1, u.Quantity):
assert u.allclose(wc1, wc2)
assert u.allclose(wc1, wc2, equal_nan=True)

elif isinstance(wc1, time.Time):
assert u.allclose((wc1 - wc2).to(u.s), 0*u.s)
Expand All @@ -258,12 +257,6 @@ def _compare_frame_output(wc1, wc2):

@fixture_all_wcses
def test_high_level_wrapper(wcsobj, request):
if request.node.callspec.params['wcsobj'] in ('gwcs_4d_identity_units', 'gwcs_stokes_lookup'):
pytest.importorskip("astropy", minversion="4.0dev0")

# Remove the bounding box because the type test is a little broken with the
# bounding box.
del wcsobj._pipeline[0].transform.bounding_box

hlvl = HighLevelWCSWrapper(wcsobj)

Expand All @@ -286,8 +279,6 @@ def test_high_level_wrapper(wcsobj, request):


def test_stokes_wrapper(gwcs_stokes_lookup):
pytest.importorskip("astropy", minversion="4.0dev0")

hlvl = HighLevelWCSWrapper(gwcs_stokes_lookup)

pixel_input = [0, 1, 2, 3]
Expand Down
4 changes: 2 additions & 2 deletions gwcs/tests/test_api_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ def test_celestial_slice(gwcs_3d_galactic_spectral):
assert_allclose(wcs.pixel_to_world_values(39, 44), (10.24, 20, 25))
assert_allclose(wcs.array_index_to_world_values(44, 39), (10.24, 20, 25))

assert_allclose(wcs.world_to_pixel_values(12.4, 20, 25), (39., 44.))
assert_equal(wcs.world_to_array_index_values(12.4, 20, 25), (44, 39))
assert_allclose(wcs.world_to_pixel_values(10.24, 20, 25), (39., 44.))
assert_equal(wcs.world_to_array_index_values(10.24, 20, 25), (44, 39))

assert_equal(wcs.pixel_bounds, [(-2, 45), (5, 50)])

Expand Down
87 changes: 87 additions & 0 deletions gwcs/tests/test_bounding_box.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import numpy as np
from numpy.testing import assert_array_equal, assert_allclose

import pytest


x = [-1, 2, 4, 13]
y = [np.nan, np.nan, 4, np.nan]
y1 = [np.nan, np.nan, 4, np.nan]


@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)),
((100, 200), (np.nan, np.nan)),
((x, x),(y, y))
])
def test_2d_spatial(gwcs_2d_spatial_shift, input, output):
w = gwcs_2d_spatial_shift
w.bounding_box = ((-.5, 21), (4, 12))

assert_array_equal(w.invert(*w(*input)), output)
assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output)
assert_array_equal(w.world_to_pixel(w.pixel_to_world(*input)), output)


@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)),
((100, 200), (np.nan, np.nan)),
((x, x), (y, y))
])
def test_2d_spatial_coordinate(gwcs_2d_quantity_shift, input, output):
w = gwcs_2d_quantity_shift
w.bounding_box = ((-.5, 21), (4, 12))

assert_array_equal(w.invert(*w(*input)), output)
assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output)
assert_array_equal(w.world_to_pixel(*w.pixel_to_world(*input)), output)


@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)),
((100, 200), (np.nan, np.nan)),
((x, x), (y, y))
])
def test_2d_spatial_coordinate_reordered(gwcs_2d_spatial_reordered, input, output):
w = gwcs_2d_spatial_reordered
w.bounding_box = ((-.5, 21), (4, 12))

assert_array_equal(w.invert(*w(*input)), output)
assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output)
assert_array_equal(w.world_to_pixel(w.pixel_to_world(*input)), output)


@pytest.mark.parametrize((("input", "output")), [(2, 2),
((10, 200), (10, np.nan)),
(x, (np.nan, 2, 4, 13))
])
def test_1d_freq(gwcs_1d_freq, input, output):
w = gwcs_1d_freq
w.bounding_box = (-.5, 21)
print(f"input {input}, {output}")
assert_array_equal(w.invert(w(input)), output)
assert_array_equal(w.world_to_pixel_values(w.pixel_to_world_values(input)), output)
assert_array_equal(w.world_to_pixel(w.pixel_to_world(input)), output)


@pytest.mark.parametrize((("input", "output")), [((2, 4, 5), (2, 4, 5)),
((100, 200, 5), (np.nan, np.nan, np.nan)),
((x, x, x), (y1, y1, y1))
])
def test_3d_spatial_wave(gwcs_3d_spatial_wave, input, output):
w = gwcs_3d_spatial_wave
w.bounding_box = ((-.5, 21), (4, 12), (3, 21))

assert_array_equal(w.invert(*w(*input)), output)
assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output)
assert_array_equal(w.world_to_pixel(*w.pixel_to_world(*input)), output)


@pytest.mark.parametrize((("input", "output")), [((1, 2, 3, 4), (1., 2., 3., 4.)),
((100, 3, 3, 3), (np.nan, 3, 3, 3)),
((x, x, x, x), [[np.nan, 2., 4., 13.],
[np.nan, 2., 4., 13.],
[np.nan, 2., 4., 13.],
[np.nan, 2., 4., np.nan]])
])
def test_gwcs_spec_cel_time_4d(gwcs_spec_cel_time_4d, input, output):
w = gwcs_spec_cel_time_4d

assert_allclose(w.invert(*w(*input, with_bounding_box=False)), output, atol=1e-8)
6 changes: 3 additions & 3 deletions gwcs/tests/test_coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def test_temporal_relative():
assert a[1] == Time("2018-01-01T00:00:00") + 20 * u.s


@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher")
#@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher")
def test_temporal_absolute():
t = cf.TemporalFrame(reference_frame=Time([], format='isot'))
assert t.coordinates("2018-01-01T00:00:00") == Time("2018-01-01T00:00:00")
Expand Down Expand Up @@ -240,7 +240,7 @@ def test_coordinate_to_quantity_spectral(inp):
(Time("2011-01-01T00:00:10"),),
(10 * u.s,)
])
@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.")
#@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.")
def test_coordinate_to_quantity_temporal(inp):
temp = cf.TemporalFrame(reference_frame=Time("2011-01-01T00:00:00"), unit=u.s)

Expand Down Expand Up @@ -325,7 +325,7 @@ def test_coordinate_to_quantity_frame_2d():
assert_quantity_allclose(output, exp)


@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.")
#@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.")
def test_coordinate_to_quantity_error():
frame = cf.Frame2D(unit=(u.one, u.arcsec))
with pytest.raises(ValueError):
Expand Down
6 changes: 6 additions & 0 deletions gwcs/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from astropy import units as u
from astropy import coordinates as coord
from astropy.modeling import models
from astropy import table

from astropy.tests.helper import assert_quantity_allclose
import pytest
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -104,6 +106,10 @@ def test_isnumerical():
assert gwutils.isnumerical(np.array(0, dtype='>f8'))
assert gwutils.isnumerical(np.array(0, dtype='>i4'))

# check a table column
t = table.Table(data=[[1,2,3], [4,5,6]], names=['x', 'y'])
assert not gwutils.isnumerical(t['x'])


def test_get_values():
args = 2 * u.cm
Expand Down
4 changes: 4 additions & 0 deletions gwcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,7 @@ def test_iter_inv():
*w(x, y),
adaptive=True,
detect_divergence=True,
tolerance=1e-4, maxiter=50,
quiet=False
)
assert np.allclose((x, y), (xp, yp))
Expand All @@ -1217,6 +1218,7 @@ def test_iter_inv():
xp, yp = w.numerical_inverse(
*w(x, y),
adaptive=True,
tolerance=1e-5, maxiter=50,
detect_divergence=False,
quiet=False
)
Expand Down Expand Up @@ -1251,6 +1253,7 @@ def test_iter_inv():
xp, yp = w.numerical_inverse(
*w(x, y, with_bounding_box=False),
adaptive=False,
tolerance=1e-5, maxiter=50,
detect_divergence=True,
quiet=False,
with_bounding_box=False
Expand All @@ -1264,6 +1267,7 @@ def test_iter_inv():
xp, yp = w.numerical_inverse(
*w(x, y, with_bounding_box=False),
adaptive=False,
tolerance=1e-5, maxiter=50,
detect_divergence=True,
quiet=False,
with_bounding_box=False
Expand Down
11 changes: 5 additions & 6 deletions gwcs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from astropy import coordinates as coords
from astropy import units as u
from astropy.time import Time, TimeDelta
from astropy import table
from astropy.wcs import Celprm


Expand Down Expand Up @@ -470,14 +471,12 @@ def isnumerical(val):
Determine if a value is numerical (number or np.array of numbers).
"""
isnum = True
if isinstance(val, coords.SkyCoord):
isnum = False
elif isinstance(val, u.Quantity):
isnum = False
elif isinstance(val, (Time, TimeDelta)):
astropy_types=(coords.SkyCoord, u.Quantity, Time, TimeDelta, table.Column, table.Row)
if isinstance(val, astropy_types):
isnum = False
elif (isinstance(val, np.ndarray)
and not np.issubdtype(val.dtype, np.floating)
and not np.issubdtype(val.dtype, np.integer)):
and not np.issubdtype(val.dtype, np.integer)
):
isnum = False
return isnum
Loading
Loading