diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index c285a7b1..967dcc92 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -7,7 +7,8 @@ import numpy as np from numpy.testing import assert_allclose, assert_equal -from astropy.modeling import models +from astropy.modeling import models, bind_compound_bounding_box +from astropy.modeling.bounding_box import ModelBoundingBox from astropy import coordinates as coord from astropy.io import fits from astropy import units as u @@ -391,6 +392,70 @@ def test_grid_from_bounding_box_step(): with pytest.raises(ValueError): grid_from_bounding_box(bb, step=(1, 2, 1)) +def test_grid_from_model_bounding_box(): + bbox = ((-1, 1), (0, 1)) + # Truth grid + grid_truth = grid_from_bounding_box(bbox) + + # Create a bounding box + model = models.Const2D() & models.Const1D() + model.inputs = ("x", "y", "slit_name") + model.bounding_box = ModelBoundingBox( + { + "x": bbox[0], + "y": bbox[1], + }, + model=model, + ignored=["slit_name"], + order="F", + ) + grid = grid_from_bounding_box(model.bounding_box) + + assert np.all(grid == grid_truth) + + # Handle the C order case + model.inputs = ("y", "x", "slit_name") + model.bounding_box = ModelBoundingBox( + { + "x": bbox[0], + "y": bbox[1], + }, + model=model, + ignored=["slit_name"], + order="C", + ) + grid = grid_from_bounding_box(model.bounding_box) + + assert np.all(grid == grid_truth) + + +def test_grid_from_compound_bounding_box(): + bbox = ((-1, 1), (0, 1)) + # Truth grid + grid_truth = grid_from_bounding_box(bbox) + + # Create a compound bounding box + model = models.Const2D() & models.Const1D() + model.inputs = ("x", "y", "slit_name") + bind_compound_bounding_box( + model, + { + (200,) : { + "x": bbox[0], + "y": bbox[1], + }, + (300,) :{ + "x": (-2, 2), + "y": (0, 2), + } + }, + [("slit_name",)], + order="F", + ) + grid = grid_from_bounding_box(model.bounding_box, selector=(200,)) + + assert np.all(grid == grid_truth) + def test_wcs_from_points(): np.random.seed(0) diff --git a/gwcs/wcstools.py b/gwcs/wcstools.py index 20987e75..be62ec7d 100644 --- a/gwcs/wcstools.py +++ b/gwcs/wcstools.py @@ -1,10 +1,12 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst +from codecs import BOM import functools import warnings import numpy as np from astropy.modeling.core import Model from astropy.modeling import projections from astropy.modeling import models, fitting +from astropy.modeling.bounding_box import CompoundBoundingBox, ModelBoundingBox from astropy import coordinates as coord from astropy import units as u @@ -139,7 +141,7 @@ def _frame2D_transform(fiducial, **kwargs): } -def grid_from_bounding_box(bounding_box, step=1, center=True): +def grid_from_bounding_box(bounding_box, step=1, center=True, selector=None): """ Create a grid of input points from the WCS bounding_box. @@ -151,11 +153,14 @@ def grid_from_bounding_box(bounding_box, step=1, center=True): Parameters ---------- - bounding_box : tuple + bounding_box : tuple | ~astropy.modeling.bounding_box.ModelBoundingBox | ~astropy.modeling.bounding_box.CompoundBoundingBox The bounding_box of a WCS object, `~gwcs.wcs.WCS.bounding_box`. step : scalar or tuple Step size for grid in each dimension. Scalar applies to all dimensions. center : bool + selector : tuple | None + If selector is set then it must be a selector tuple and bounding_box must + be a CompoundBoundingBox. The bounding_box is in order of X, Y [, Z] and the output will be in the same order. @@ -187,6 +192,30 @@ def grid_from_bounding_box(bounding_box, step=1, center=True): """ def _bbox_to_pixel(bbox): return (np.floor(bbox[0] + 0.5), np.ceil(bbox[1] - 0.5)) + + if selector is not None and not isinstance(bounding_box, CompoundBoundingBox): + raise ValueError("Cannot use selector with a non-CompoundBoundingBox") + + if isinstance(bounding_box, CompoundBoundingBox): + if selector is None: + raise ValueError("selector must be set when bounding_box is a CompoundBoundingBox") + + bounding_box = bounding_box[selector] + + if isinstance(bounding_box, ModelBoundingBox): + input_names = bounding_box.model.inputs + + # Reorder the bounding box to match the order of the inputs + if bounding_box.order == "C": + input_names = input_names[::-1] + + # Get tuple of tuples of the bounding box values + bounding_box = tuple( + tuple(bounding_box[name]) + for name in input_names + if name not in bounding_box.ignored_inputs + ) + # 1D case if np.isscalar(bounding_box[0]): nd = 1