Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing RGB xarray.DataArray images into grdimage #2590

Merged
merged 12 commits into from
Aug 8, 2023
22 changes: 19 additions & 3 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ctypes as ctp
import pathlib
import sys
import warnings
from contextlib import contextmanager, nullcontext

import numpy as np
Expand All @@ -26,7 +27,12 @@
GMTInvalidInput,
GMTVersionError,
)
from pygmt.helpers import data_kind, fmt_docstring, tempfile_from_geojson
from pygmt.helpers import (
data_kind,
fmt_docstring,
tempfile_from_geojson,
tempfile_from_image,
)

FAMILIES = [
"GMT_IS_DATASET", # Entity is a data table
Expand Down Expand Up @@ -1540,7 +1546,7 @@ def virtualfile_from_data(
if check_kind:
valid_kinds = ("file", "arg") if required_data is False else ("file",)
if check_kind == "raster":
valid_kinds += ("grid",)
valid_kinds += ("grid", "image")
elif check_kind == "vector":
valid_kinds += ("matrix", "vectors", "geojson")
if kind not in valid_kinds:
Expand All @@ -1554,6 +1560,7 @@ def virtualfile_from_data(
"arg": nullcontext,
"geojson": tempfile_from_geojson,
"grid": self.virtualfile_from_grid,
"image": tempfile_from_image,
# Note: virtualfile_from_matrix is not used because a matrix can be
# converted to vectors instead, and using vectors allows for better
# handling of string type inputs (e.g. for datetime data types)
Expand All @@ -1562,7 +1569,16 @@ def virtualfile_from_data(
}[kind]

# Ensure the data is an iterable (Python list or tuple)
if kind in ("geojson", "grid", "file", "arg"):
if kind in ("geojson", "grid", "image", "file", "arg"):
if kind == "image" and data.dtype != "uint8":
msg = (
f"Input image has dtype: {data.dtype} which is unsupported, "
"and may result in an incorrect output. Please recast image "
"to a uint8 dtype and/or scale to 0-255 range, e.g. "
"using a histogram equalization function like "
"skimage.exposure.equalize_hist."
)
warnings.warn(message=msg, category=RuntimeWarning, stacklevel=2)
Comment on lines +1573 to +1581
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Took a look at grdhisteq, but it doesn't appear to support 3-band image inputs, only 1-band grids, so suggesting to use skimage.exposure.equalize_hist instead. Could probably raise a feature request to upstream GMT to let grdhisteq support this too?

_data = (data,) if not isinstance(data, pathlib.PurePath) else (str(data),)
elif kind == "vectors":
_data = [np.atleast_1d(x), np.atleast_1d(y)]
Expand Down
7 changes: 6 additions & 1 deletion pygmt/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
kwargs_to_strings,
use_alias,
)
from pygmt.helpers.tempfile import GMTTempFile, tempfile_from_geojson, unique_name
from pygmt.helpers.tempfile import (
GMTTempFile,
tempfile_from_geojson,
tempfile_from_image,
unique_name,
)
from pygmt.helpers.utils import (
args_in_kwargs,
build_arg_string,
Expand Down
31 changes: 31 additions & 0 deletions pygmt/helpers/tempfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,34 @@ def tempfile_from_geojson(geojson):
geoseries.to_file(**ogrgmt_kwargs)

yield tmpfile.name


@contextmanager
def tempfile_from_image(image):
"""
Saves a 3-band :class:`xarray.DataArray` to a temporary GeoTIFF file via
rioxarray.

Parameters
----------
image : xarray.DataArray
An xarray.DataArray with three dimensions, having a shape like
(3, Y, X).

Yields
------
tmpfilename : str
A temporary GeoTIFF file holding the image data. E.g. '1a2b3c4d5.tif'.
"""
with GMTTempFile(suffix=".tif") as tmpfile:
os.remove(tmpfile.name) # ensure file is deleted first
try:
image.rio.to_raster(raster_path=tmpfile.name)
except AttributeError as e: # object has no attribute 'rio'
raise ImportError(
"Package `rioxarray` is required to be installed to use this function. "
"Please use `python -m pip install rioxarray` or "
"`mamba install -c conda-forge rioxarray` "
"to install the package."
) from e
yield tmpfile.name
8 changes: 5 additions & 3 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data
Returns
-------
kind : str
One of ``'arg'``, ``'file'``, ``'grid'``, ``'geojson'``, ``'matrix'``,
or ``'vectors'``.
One of ``'arg'``, ``'file'``, ``'grid'``, ``image``, ``'geojson'``,
``'matrix'``, or ``'vectors'``.

Examples
--------
Expand All @@ -166,14 +166,16 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data
'arg'
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3)))
'grid'
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5)))
'image'
"""
# determine the data kind
if isinstance(data, (str, pathlib.PurePath)):
kind = "file"
elif isinstance(data, (bool, int, float)) or (data is None and not required_data):
kind = "arg"
elif isinstance(data, xr.DataArray):
kind = "grid"
kind = "image" if len(data.dims) == 3 else "grid"
elif hasattr(data, "__geo_interface__"):
# geo-like Python object that implements ``__geo_interface__``
# (geopandas.GeoDataFrame or shapely.geometry)
Expand Down
27 changes: 12 additions & 15 deletions pygmt/src/grdimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def grdimage(self, grid, **kwargs):
instructions to derive intensities from the input data grid. Values outside
this range will be clipped. Such intensity files can be created from the
grid using :func:`pygmt.grdgradient` and, optionally, modified by
:gmt-docs:`grdmath.html` or :class:`pygmt.grdhisteq`. If GMT is built
with GDAL support, ``grid`` can be an image file (geo-referenced or not).
In this case the image can optionally be illuminated with the file
provided via the ``shading`` parameter. Here, if image has no coordinates
then those of the intensity file will be used.
:gmt-docs:`grdmath.html` or :class:`pygmt.grdhisteq`. Alternatively, pass
*image* which can be an image file (geo-referenced or not). In this case
the image can optionally be illuminated with the file provided via the
``shading`` parameter. Here, if image has no coordinates then those of the
intensity file will be used.

When using map projections, the grid is first resampled on a new
rectangular grid with the same dimensions. Higher resolution images can
Expand Down Expand Up @@ -74,10 +74,7 @@ def grdimage(self, grid, **kwargs):
:gmt-docs:`grdimage.html#grid-file-formats`).
img_out : str
*out_img*\[=\ *driver*].
Save an image in a raster format instead of PostScript. Use
extension .ppm for a Portable Pixel Map format which is the only
raster format GMT can natively write. For GMT installations
configured with GDAL support there are more choices: Append
Save an image in a raster format instead of PostScript. Append
*out_img* to select the image file name and extension. If the
extension is one of .bmp, .gif, .jpg, .png, or .tif then no driver
information is required. For other output formats you must append
Expand Down Expand Up @@ -131,8 +128,8 @@ def grdimage(self, grid, **kwargs):
:func:`pygmt.grdgradient` separately first. If we should derive
intensities from another file than grid, specify the file with
suitable modifiers [Default is no illumination]. **Note**: If the
input data is an *image* then an *intensfile* or constant *intensity*
must be provided.
input data represent an *image* then an *intensfile* or constant
*intensity* must be provided.
{projection}
monochrome : bool
Force conversion to monochrome image using the (television) YIQ
Expand All @@ -144,10 +141,9 @@ def grdimage(self, grid, **kwargs):
[**+z**\ *value*][*color*]
Make grid nodes with z = NaN transparent, using the color-masking
feature in PostScript Level 3 (the PS device must support PS Level
3). If the input is a grid, use **+z** with a *value* to select
another grid value than NaN. If the input is instead an image,
append an alternate *color* to select another pixel value to be
transparent [Default is ``"black"``].
3). If the input is a grid, use **+z** to select another grid value
than NaN. If input is instead an image, append an alternate *color* to
select another pixel value to be transparent [Default is ``"black"``].
{region}
{verbose}
{panel}
Expand All @@ -171,6 +167,7 @@ def grdimage(self, grid, **kwargs):
>>> fig.show()
"""
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access

with Session() as lib:
with lib.virtualfile_from_data(
check_kind="raster", data=grid
Expand Down
16 changes: 5 additions & 11 deletions pygmt/src/tilemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,7 @@
"""
from pygmt.clib import Session
from pygmt.datasets.tile_map import load_tile_map
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
fmt_docstring,
kwargs_to_strings,
use_alias,
)
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias

try:
import rioxarray
Expand Down Expand Up @@ -148,9 +142,9 @@ def tilemap(
if kwargs.get("N") in [None, False]:
kwargs["R"] = "/".join(str(coordinate) for coordinate in region)

with GMTTempFile(suffix=".tif") as tmpfile:
raster.rio.to_raster(raster_path=tmpfile.name)
with Session() as lib:
with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind="raster", data=raster)
with file_context as infile:
lib.call_module(
module="grdimage", args=build_arg_string(kwargs, infile=tmpfile.name)
module="grdimage", args=build_arg_string(kwargs, infile=infile)
)
4 changes: 4 additions & 0 deletions pygmt/tests/baseline/test_grdimage_image.png.dvc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
outs:
- md5: 2e919645d5af956ec4f8aa054a86a70a
size: 110214
path: test_grdimage_image.png
79 changes: 79 additions & 0 deletions pygmt/tests/test_grdimage_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Test Figure.grdimage on 3-band RGB images.
"""
import numpy as np
import pandas as pd
import pytest
import xarray as xr
from pygmt import Figure, which

rasterio = pytest.importorskip("rasterio")
rioxarray = pytest.importorskip("rioxarray")


@pytest.fixture(scope="module", name="xr_image")
def fixture_xr_image():
"""
Load the image data from Blue Marble as an xarray.DataArray with shape
{"band": 3, "y": 180, "x": 360}.
"""
geotiff = which(fname="@earth_day_01d_p", download="c")
with rioxarray.open_rasterio(filename=geotiff) as rda:
if len(rda.band) == 1:
with rasterio.open(fp=geotiff) as src:
df_colormap = pd.DataFrame.from_dict(
data=src.colormap(1), orient="index"
)
array = src.read()

red = np.vectorize(df_colormap[0].get)(array)
green = np.vectorize(df_colormap[1].get)(array)
blue = np.vectorize(df_colormap[2].get)(array)
# alpha = np.vectorize(df_colormap[3].get)(array)

rda.data = red
da_red = rda.astype(dtype=np.uint8).copy()
rda.data = green
da_green = rda.astype(dtype=np.uint8).copy()
rda.data = blue
da_blue = rda.astype(dtype=np.uint8).copy()

xr_image = xr.concat(objs=[da_red, da_green, da_blue], dim="band")
assert xr_image.sizes == {"band": 3, "y": 180, "x": 360}
return xr_image
Comment on lines +20 to +43
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should replace this with the load_blue_marble function from #2235 once done. Ideally, the rioxarray.open_rasterio function would be able to load this 3-band GeoTIFF from colorinterp directly without the hacky logic here.



@pytest.mark.mpl_image_compare
def test_grdimage_image():
"""
Plot a 3-band RGB image using file input.
"""
fig = Figure()
fig.grdimage(grid="@earth_day_01d")
return fig


@pytest.mark.mpl_image_compare(filename="test_grdimage_image.png")
def test_grdimage_image_dataarray(xr_image):
"""
Plot a 3-band RGB image using xarray.DataArray input.
"""
fig = Figure()
fig.grdimage(grid=xr_image)
return fig


@pytest.mark.parametrize(
"dtype",
["int8", "uint16", "int16", "uint32", "int32", "float32", "float64"],
)
def test_grdimage_image_dataarray_unsupported_dtype(dtype, xr_image):
"""
Plot a 3-band RGB image using xarray.DataArray input, with an unsupported
data type.
"""
fig = Figure()
image = xr_image.astype(dtype=dtype)
with pytest.warns(expected_warning=RuntimeWarning) as record:
fig.grdimage(grid=image)
assert len(record) == 1
Loading