Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing RGB xarray.DataArray images into grdimage #2590

Merged
merged 12 commits into from
Aug 8, 2023
35 changes: 22 additions & 13 deletions pygmt/src/grdimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pygmt.clib import Session
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
data_kind,
fmt_docstring,
Expand Down Expand Up @@ -179,17 +180,25 @@ def grdimage(self, grid, **kwargs):
>>> fig.show()
"""
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with contextlib.ExitStack() as stack:
# shading using an xr.DataArray
if kwargs.get("I") is not None and data_kind(kwargs["I"]) == "grid":
shading_context = lib.virtualfile_from_data(
check_kind="raster", data=kwargs["I"]
)
kwargs["I"] = stack.enter_context(shading_context)

fname = stack.enter_context(file_context)
lib.call_module(
module="grdimage", args=build_arg_string(kwargs, infile=fname)
)
with GMTTempFile(suffix=".tif") as tmpfile:
if hasattr(grid, "dims") and len(grid.dims) == 3:
grid.rio.to_raster(raster_path=tmpfile.name)
_grid = tmpfile.name
else:
_grid = grid
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should maybe put this logic in lib.virtualfile_from_data, since we'll need to reuse it for grdview, grdcut, and possibly other GMT modules that work with GMT_IMAGE.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, created a new tempfile_from_image function at 89da916 (similar to tempfile_from_geojson). This can be replaced with lib.virtualfile_from_image once that is implemented.


with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind="raster", data=_grid)
with contextlib.ExitStack() as stack:
# shading using an xr.DataArray
if kwargs.get("I") is not None and data_kind(kwargs["I"]) == "grid":
shading_context = lib.virtualfile_from_data(
check_kind="raster", data=kwargs["I"]
)
kwargs["I"] = stack.enter_context(shading_context)

fname = stack.enter_context(file_context)
lib.call_module(
module="grdimage", args=build_arg_string(kwargs, infile=fname)
)
4 changes: 4 additions & 0 deletions pygmt/tests/baseline/test_grdimage_image.png.dvc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
outs:
- md5: 2e919645d5af956ec4f8aa054a86a70a
size: 110214
path: test_grdimage_image.png
63 changes: 63 additions & 0 deletions pygmt/tests/test_grdimage_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
Test Figure.grdimage on 3-band RGB images.
"""
import numpy as np
import pandas as pd
import pytest
import xarray as xr
from pygmt import Figure, which

rasterio = pytest.importorskip("rasterio")
rioxarray = pytest.importorskip("rioxarray")


@pytest.fixture(scope="module", name="xr_image")
def fixture_xr_image():
"""
Load the image data from Blue Marble as an xarray.DataArray with shape
{"band": 3, "y": 180, "x": 360}.
"""
geotiff = which(fname="@earth_day_01d_p", download="c")
with rioxarray.open_rasterio(filename=geotiff) as rda:
if len(rda.band) == 1:
with rasterio.open(fp=geotiff) as src:
df_colormap = pd.DataFrame.from_dict(
data=src.colormap(1), orient="index"
)
array = src.read()

red = np.vectorize(df_colormap[0].get)(array)
green = np.vectorize(df_colormap[1].get)(array)
blue = np.vectorize(df_colormap[2].get)(array)
# alpha = np.vectorize(df_colormap[3].get)(array)

rda.data = red
da_red = rda.astype(dtype=np.uint8).copy()
rda.data = green
da_green = rda.astype(dtype=np.uint8).copy()
rda.data = blue
da_blue = rda.astype(dtype=np.uint8).copy()

xr_image = xr.concat(objs=[da_red, da_green, da_blue], dim="band")
assert xr_image.sizes == {"band": 3, "y": 180, "x": 360}
return xr_image
Comment on lines +20 to +43
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should replace this with the load_blue_marble function from #2235 once done. Ideally, the rioxarray.open_rasterio function would be able to load this 3-band GeoTIFF from colorinterp directly without the hacky logic here.



@pytest.mark.mpl_image_compare
def test_grdimage_image():
"""
Plot a 3-band RGB image using file input.
"""
fig = Figure()
fig.grdimage(grid="@earth_day_01d")
return fig


@pytest.mark.mpl_image_compare(filename="test_grdimage_image.png")
def test_grdimage_image_dataarray(xr_image):
"""
Plot a 3-band RGB image using xarray.DataArray input.
"""
fig = Figure()
fig.grdimage(grid=xr_image)
return fig