diff --git a/Tests/images/flower.jxl b/Tests/images/flower.jxl new file mode 100644 index 00000000000..aeb8ae79165 Binary files /dev/null and b/Tests/images/flower.jxl differ diff --git a/Tests/images/flower2.jxl b/Tests/images/flower2.jxl new file mode 100644 index 00000000000..30d45a13d16 Binary files /dev/null and b/Tests/images/flower2.jxl differ diff --git a/Tests/images/hopper.jxl b/Tests/images/hopper.jxl new file mode 100644 index 00000000000..d89d3c267fe Binary files /dev/null and b/Tests/images/hopper.jxl differ diff --git a/Tests/images/hopper_jxl_bits.ppm b/Tests/images/hopper_jxl_bits.ppm new file mode 100644 index 00000000000..881aca33369 Binary files /dev/null and b/Tests/images/hopper_jxl_bits.ppm differ diff --git a/Tests/images/iss634.jxl b/Tests/images/iss634.jxl new file mode 100644 index 00000000000..99c2cf03633 Binary files /dev/null and b/Tests/images/iss634.jxl differ diff --git a/Tests/images/jxl/16bit_subcutaneous.cropped.jxl b/Tests/images/jxl/16bit_subcutaneous.cropped.jxl new file mode 100644 index 00000000000..eae30759603 Binary files /dev/null and b/Tests/images/jxl/16bit_subcutaneous.cropped.jxl differ diff --git a/Tests/images/jxl/16bit_subcutaneous.cropped.png b/Tests/images/jxl/16bit_subcutaneous.cropped.png new file mode 100644 index 00000000000..b337f7bddbe Binary files /dev/null and b/Tests/images/jxl/16bit_subcutaneous.cropped.png differ diff --git a/Tests/images/jxl/traffic_light.gif b/Tests/images/jxl/traffic_light.gif new file mode 100644 index 00000000000..4f7ecfdbcd7 Binary files /dev/null and b/Tests/images/jxl/traffic_light.gif differ diff --git a/Tests/images/jxl/traffic_light.jxl b/Tests/images/jxl/traffic_light.jxl new file mode 100644 index 00000000000..c777e3bd618 Binary files /dev/null and b/Tests/images/jxl/traffic_light.jxl differ diff --git a/Tests/images/transparent.jxl b/Tests/images/transparent.jxl new file mode 100644 index 00000000000..cea19bb6cfa Binary files /dev/null and b/Tests/images/transparent.jxl differ diff --git a/Tests/test_file_jxl.py b/Tests/test_file_jxl.py new file mode 100644 index 00000000000..c1b730a79e7 --- /dev/null +++ b/Tests/test_file_jxl.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import re + +import pytest + +from PIL import Image, JpegXlImagePlugin, features + +from .helper import ( + assert_image_similar_tofile, + skip_unless_feature, +) + +try: + from PIL import _jpegxl + + HAVE_JPEGXL = True +except ImportError: + HAVE_JPEGXL = False + +# cjxl v0.9.2 41b8cdab +# hopper.jxl: cjxl hopper.png hopper.jxl -q 75 -e 8 +# 16_bit_binary.jxl: cjxl 16_bit_binary.pgm 16_bit_binary.jxl -q 100 -e 9 + + +class TestUnsupportedJpegXl: + def test_unsupported(self) -> None: + if HAVE_JPEGXL: + JpegXlImagePlugin.SUPPORTED = False + + file_path = "Tests/images/hopper.jxl" + with pytest.raises(OSError): + with Image.open(file_path): + pass + + if HAVE_JPEGXL: + JpegXlImagePlugin.SUPPORTED = True + + +@skip_unless_feature("jpegxl") +class TestFileJpegXl: + def setup_method(self) -> None: + self.rgb_mode = "RGB" + self.i16_mode = "I;16" + + def test_version(self) -> None: + _jpegxl.JpegXlDecoderVersion() + assert re.search(r"\d+\.\d+\.\d+$", features.version_module("jpegxl")) + + def test_read_rgb(self) -> None: + """ + Can we read a RGB mode Jpeg XL file without error? + Does it have the bits we expect? + """ + + with Image.open("Tests/images/hopper.jxl") as image: + assert image.mode == self.rgb_mode + assert image.size == (128, 128) + assert image.format == "JPEG XL" + image.load() + image.getdata() + + # generated with: + # djxl hopper.jxl hopper_jxl_bits.ppm + assert_image_similar_tofile(image, "Tests/images/hopper_jxl_bits.ppm", 1.0) + + def test_read_i16(self) -> None: + """ + Can we read 16-bit Grayscale Jpeg XL image? + """ + + with Image.open("Tests/images/jxl/16bit_subcutaneous.cropped.jxl") as image: + assert image.mode == self.i16_mode + assert image.size == (128, 64) + assert image.format == "JPEG XL" + image.load() + image.getdata() + + assert_image_similar_tofile( + image, "Tests/images/jxl/16bit_subcutaneous.cropped.png", 1.0 + ) + + def test_JpegXlDecode_with_invalid_args(self) -> None: + """ + Calling decoder functions with no arguments should result in an error. + """ + + with pytest.raises(TypeError): + _jpegxl.PILJpegXlDecoder() diff --git a/Tests/test_file_jxl_alpha.py b/Tests/test_file_jxl_alpha.py new file mode 100644 index 00000000000..8c3ab2b7111 --- /dev/null +++ b/Tests/test_file_jxl_alpha.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import pytest + +from PIL import Image + +from .helper import assert_image_similar_tofile + +_jpegxl = pytest.importorskip("PIL._jpegxl", reason="JPEG XL support not installed") + + +def test_read_rgba() -> None: + """ + Can we read an RGBA mode file without error? + Does it have the bits we expect? + """ + + # Generated with `cjxl transparent.png transparent.jxl -q 100 -e 8` + file_path = "Tests/images/transparent.jxl" + with Image.open(file_path) as image: + assert image.mode == "RGBA" + assert image.size == (200, 150) + assert image.format == "JPEG XL" + image.load() + image.getdata() + + image.tobytes() + + assert_image_similar_tofile(image, "Tests/images/transparent.png", 1.0) diff --git a/Tests/test_file_jxl_animated.py b/Tests/test_file_jxl_animated.py new file mode 100644 index 00000000000..758fa79e2d8 --- /dev/null +++ b/Tests/test_file_jxl_animated.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import pytest + +from PIL import Image + +from .helper import ( + assert_image_equal, + skip_unless_feature, +) + +pytestmark = [ + skip_unless_feature("jpegxl"), +] + + +def test_n_frames() -> None: + """Ensure that jxl format sets n_frames and is_animated attributes correctly.""" + + with Image.open("Tests/images/hopper.jxl") as im: + assert im.n_frames == 1 + assert not im.is_animated + + with Image.open("Tests/images/iss634.jxl") as im: + assert im.n_frames == 41 + assert im.is_animated + + +def test_float_duration() -> None: + + with Image.open("Tests/images/iss634.jxl") as im: + im.load() + assert im.info["duration"] == 70 + + +def test_seeking() -> None: + """ + Open an animated jxl file, and then try seeking through frames in reverse-order, + verifying the durations are correct. + """ + + with Image.open("Tests/images/jxl/traffic_light.jxl") as im1: + with Image.open("Tests/images/jxl/traffic_light.gif") as im2: + assert im1.n_frames == im2.n_frames + assert im1.is_animated + + # Traverse frames in reverse, checking timestamps and durations + total_dur = 0 + for frame in reversed(range(im1.n_frames)): + im1.seek(frame) + im1.load() + im2.seek(frame) + im2.load() + + assert_image_equal(im1.convert("RGB"), im2.convert("RGB")) + + total_dur += im1.info["duration"] + assert im1.info["duration"] == im2.info["duration"] + assert im1.info["timestamp"] == im1.info["timestamp"] + assert total_dur == 8000 + + assert im1.tell() == 0 and im2.tell() == 0 + + im1.seek(0) + im1.load() + im2.seek(0) + im2.load() + + +def test_seek_errors() -> None: + with Image.open("Tests/images/iss634.jxl") as im: + with pytest.raises(EOFError): + im.seek(-1) + + with pytest.raises(EOFError): + im.seek(47) diff --git a/Tests/test_file_jxl_metadata.py b/Tests/test_file_jxl_metadata.py new file mode 100644 index 00000000000..f6be96c8062 --- /dev/null +++ b/Tests/test_file_jxl_metadata.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from types import ModuleType + +import pytest + +from PIL import Image + +from .helper import skip_unless_feature + +pytestmark = [ + skip_unless_feature("jpegxl"), +] + +ElementTree: ModuleType | None +try: + from defusedxml import ElementTree +except ImportError: + ElementTree = None + + +# cjxl flower.jpg flower.jxl --lossless_jpeg=0 -q 75 -e 8 + +# >>> from PIL import Image +# >>> with Image.open('Tests/images/flower2.webp') as im: +# >>> with open('/tmp/xmp.xml', 'wb') as f: +# >>> f.write(im.info['xmp']) +# cjxl flower2.jpg flower2.jxl --lossless_jpeg=0 -q 75 -e 8 -x xmp=/tmp/xmp.xml + + +def test_read_exif_metadata() -> None: + file_path = "Tests/images/flower.jxl" + with Image.open(file_path) as image: + assert image.format == "JPEG XL" + exif_data = image.info.get("exif", None) + assert exif_data + + exif = image._getexif() + + # Camera make + assert exif[271] == "Canon" + + with Image.open("Tests/images/flower.jpg") as jpeg_image: + expected_exif = jpeg_image.info["exif"] + + # jpeg xl always returns exif without 'Exif\0\0' prefix + assert exif_data == expected_exif[6:] + + +def test_read_exif_metadata_without_prefix() -> None: + with Image.open("Tests/images/flower2.jxl") as im: + # Assert prefix is not present + assert im.info["exif"][:6] != b"Exif\x00\x00" + + exif = im.getexif() + assert exif[305] == "Adobe Photoshop CS6 (Macintosh)" + + +def test_read_icc_profile() -> None: + file_path = "Tests/images/flower2.jxl" + with Image.open(file_path) as image: + assert image.format == "JPEG XL" + assert image.info.get("icc_profile", None) + + icc = image.info["icc_profile"] + + with Image.open("Tests/images/flower2.jxl") as jpeg_image: + expected_icc = jpeg_image.info["icc_profile"] + + assert icc == expected_icc + + +def test_getxmp() -> None: + with Image.open("Tests/images/flower.jxl") as im: + assert "xmp" not in im.info + assert im.getxmp() == {} + + with Image.open("Tests/images/flower2.jxl") as im: + if ElementTree: + assert ( + im.getxmp()["xmpmeta"]["xmptk"] + == "Adobe XMP Core 5.3-c011 66.145661, 2012/02/06-14:56:27 " + ) + else: + with pytest.warns( + UserWarning, + match="XMP data cannot be read without defusedxml dependency", + ): + assert im.getxmp() == {} + + +def test_fix_exif_fail() -> None: + with Image.open("Tests/images/flower2.jxl") as image: + assert image._fix_exif(b"\0\0\0\0") is None + + +def test_read_exif_metadata_empty() -> None: + with Image.open("Tests/images/hopper.jxl") as image: + assert image._getexif() is None diff --git a/Tests/test_jxl_leaks.py b/Tests/test_jxl_leaks.py new file mode 100644 index 00000000000..cec9f152894 --- /dev/null +++ b/Tests/test_jxl_leaks.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from io import BytesIO + +from PIL import Image + +from .helper import PillowLeakTestCase, skip_unless_feature + +TEST_FILE = "Tests/images/hopper.jxl" + + +@skip_unless_feature("jpegxl") +class TestJpegXlLeaks(PillowLeakTestCase): + mem_limit = 6 * 1024 # kb + iterations = 1000 + + def test_leak_load(self) -> None: + with open(TEST_FILE, "rb") as f: + im_data = f.read() + + def core() -> None: + with Image.open(BytesIO(im_data)) as im: + im.load() + + self._test_leak(core) diff --git a/setup.py b/setup.py index 60707083f6e..eeb050ed3ad 100644 --- a/setup.py +++ b/setup.py @@ -297,6 +297,7 @@ class ext_feature: features = [ "zlib", "jpeg", + "jpegxl", "tiff", "freetype", "raqm", @@ -735,6 +736,14 @@ def build_extensions(self) -> None: feature.set("jpeg2000", "openjp2") feature.set("openjpeg_version", ".".join(str(x) for x in best_version)) + if feature.want("jpegxl"): + _dbg("Looking for jpegxl") + if _find_include_file(self, "jxl/encode.h") and _find_include_file( + self, "jxl/decode.h" + ): + if _find_library_file(self, "jxl"): + feature.set("jpegxl", "jxl jxl_threads") + if feature.want("imagequant"): _dbg("Looking for imagequant") if _find_include_file(self, "libimagequant.h"): @@ -818,6 +827,15 @@ def build_extensions(self) -> None: # alternate Windows name. feature.set("lcms", "lcms2_static") + if feature.get("jpegxl"): + # jxl and jxl_threads are required + libs = feature.get("jpegxl").split() + defs = [] + + self._update_extension("PIL._jpegxl", libs, defs) + else: + self._remove_extension("PIL._jpegxl") + if feature.want("webp"): _dbg("Looking for webp") if all( @@ -967,6 +985,7 @@ def summary_report(self, feature: ext_feature) -> None: (feature.get("freetype"), "FREETYPE2"), (feature.get("raqm"), "RAQM (Text shaping)", raqm_extra_info), (feature.get("lcms"), "LITTLECMS2"), + (feature.get("jpegxl"), "JPEG XL"), (feature.get("webp"), "WEBP"), (feature.get("xcb"), "XCB (X protocol)"), ] @@ -1010,6 +1029,7 @@ def debug_build() -> bool: Extension("PIL._imaging", files), Extension("PIL._imagingft", ["src/_imagingft.c"]), Extension("PIL._imagingcms", ["src/_imagingcms.c"]), + Extension("PIL._jpegxl", ["src/_jpegxl.c"]), Extension("PIL._webp", ["src/_webp.c"]), Extension("PIL._imagingtk", ["src/_imagingtk.c", "src/Tk/tkImaging.c"]), Extension("PIL._imagingmath", ["src/_imagingmath.c"]), diff --git a/src/PIL/JpegXlImagePlugin.py b/src/PIL/JpegXlImagePlugin.py new file mode 100644 index 00000000000..4e475bf4a6f --- /dev/null +++ b/src/PIL/JpegXlImagePlugin.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import struct +from io import BytesIO + +from . import Image, ImageFile + +try: + from . import _jpegxl + + SUPPORTED = True +except ImportError: + SUPPORTED = False + + +## Future idea: +## it's not known how many frames does animated image have +## by default, _jxl_decoder_new will iterate over all frames without decoding them +## then libjxl decoder is rewinded and we're ready to decode frame by frame +## if OPEN_COUNTS_FRAMES is False, n_frames will be None until the last frame is decoded +## it only applies to animated jpeg xl images +# OPEN_COUNTS_FRAMES = True + + +def _accept(prefix: bytes) -> bool: + is_jxl = ( + prefix[:2] == b"\xff\x0a" + or prefix[:12] == b"\x00\x00\x00\x0c\x4a\x58\x4c\x20\x0d\x0a\x87\x0a" + ) + if is_jxl and not SUPPORTED: + msg = "image file could not be identified because JXL support not installed" + raise SyntaxError(msg) + return is_jxl + + +class JpegXlImageFile(ImageFile.ImageFile): + format = "JPEG XL" + format_description = "JPEG XL image" + __loaded = 0 + __logical_frame = 0 + + def _open(self) -> None: + self._decoder = _jpegxl.PILJpegXlDecoder(self.fp.read()) + + width, height, mode, has_anim, tps_num, tps_denom, n_loops, n_frames = ( + self._decoder.get_info() + ) + self._size = width, height + self.info["loop"] = n_loops + self.is_animated = has_anim + + self._tps_dur_secs = 1 + self.n_frames: int | None = 1 + if self.is_animated: + self.n_frames = None + if n_frames > 0: + self.n_frames = n_frames + self._tps_dur_secs = tps_num / tps_denom + + # TODO: handle libjxl time codes + self.__timestamp = 0 + + self._mode = mode + self.rawmode = mode + self.tile = [] + + if icc := self._decoder.get_icc(): + self.info["icc_profile"] = icc + if exif := self._decoder.get_exif(): + self.info["exif"] = self._fix_exif(exif) + if xmp := self._decoder.get_xmp(): + self.info["xmp"] = xmp + + self._rewind() + + def _fix_exif(self, exif: bytes) -> bytes | None: + # jpeg xl does some weird shenanigans when storing exif + # it omits first 6 bytes of tiff header but adds 4 byte offset instead + if len(exif) <= 4: + return None + exif_start_offset = struct.unpack(">I", exif[:4])[0] + return exif[exif_start_offset + 4 :] + + def _getexif(self) -> dict[str, str] | None: + if "exif" not in self.info: + return None + return self.getexif()._get_merged_dict() + + def getxmp(self) -> dict[str, str]: + return self._getxmp(self.info["xmp"]) if "xmp" in self.info else {} + + def _get_next(self) -> tuple[bytes, float, float, bool]: + + # Get next frame + next_frame = self._decoder.get_next() + self.__physical_frame += 1 + + # this actually means EOF, errors are raised in _jxl + if next_frame is None: + msg = "failed to decode next frame in JXL file" + raise EOFError(msg) + + data, tps_duration, is_last = next_frame + if is_last and self.n_frames is None: + # libjxl said this frame is the last one + self.n_frames = self.__physical_frame + + # duration in miliseconds + duration = 1000 * tps_duration * (1 / self._tps_dur_secs) + timestamp = self.__timestamp + self.__timestamp += duration + + return data, timestamp, duration, is_last + + def _rewind(self, hard: bool = False) -> None: + if hard: + self._decoder.rewind() + self.__physical_frame = 0 + self.__loaded = -1 + self.__timestamp = 0 + + def _seek_check(self, frame: int) -> bool: + # if image is not animated then only the 0th frame is available + if (not self.is_animated and frame != 0) or ( + self.n_frames is not None and (frame >= self.n_frames or frame < 0) + ): + msg = "attempt to seek outside sequence" + raise EOFError(msg) + + return self.tell() != frame + + def _seek(self, frame: int) -> None: + # print("_seek: phy: {}, fr: {}".format(self.__physical_frame, frame)) + if frame == self.__physical_frame: + return # Nothing to do + if frame < self.__physical_frame: + # also rewind libjxl decoder instance + self._rewind(hard=True) + + while self.__physical_frame < frame: + self._get_next() # Advance to the requested frame + + def seek(self, frame: int) -> None: + if not self._seek_check(frame): + return + + # Set logical frame to requested position + self.__logical_frame = frame + + def load(self): + + if self.__loaded != self.__logical_frame: + self._seek(self.__logical_frame) + + data, timestamp, duration, is_last = self._get_next() + self.info["timestamp"] = timestamp + self.info["duration"] = duration + self.__loaded = self.__logical_frame + + # Set tile + if self.fp and self._exclusive_fp: + self.fp.close() + # this is horribly memory inefficient + # you need probably 2*(raw image plane) bytes of memory + self.fp = BytesIO(data) + self.tile = [("raw", (0, 0) + self.size, 0, self.rawmode)] + + return super().load() + + def load_seek(self, pos: int) -> None: + pass + + def tell(self) -> int: + return self.__logical_frame + + +Image.register_open(JpegXlImageFile.format, JpegXlImageFile, _accept) +Image.register_extension(JpegXlImageFile.format, ".jxl") +Image.register_mime(JpegXlImageFile.format, "image/jxl") diff --git a/src/PIL/__init__.py b/src/PIL/__init__.py index 09546fe6333..68254d36e4a 100644 --- a/src/PIL/__init__.py +++ b/src/PIL/__init__.py @@ -47,6 +47,7 @@ "IptcImagePlugin", "JpegImagePlugin", "Jpeg2KImagePlugin", + "JpegXlImagePlugin", "McIdasImagePlugin", "MicImagePlugin", "MpegImagePlugin", diff --git a/src/PIL/_jpegxl.pyi b/src/PIL/_jpegxl.pyi new file mode 100644 index 00000000000..b0235555dc5 --- /dev/null +++ b/src/PIL/_jpegxl.pyi @@ -0,0 +1,5 @@ +from __future__ import annotations + +from typing import Any + +def __getattr__(name: str) -> Any: ... diff --git a/src/PIL/features.py b/src/PIL/features.py index 24c5ee978b3..4ee9a207df5 100644 --- a/src/PIL/features.py +++ b/src/PIL/features.py @@ -16,6 +16,7 @@ "tkinter": ("PIL._tkinter_finder", "tk_version"), "freetype2": ("PIL._imagingft", "freetype2_version"), "littlecms2": ("PIL._imagingcms", "littlecms_version"), + "jpegxl": ("PIL._jpegxl", "libjxl_version"), "webp": ("PIL._webp", "webpdecoder_version"), } @@ -285,6 +286,7 @@ def pilinfo(out: IO[str] | None = None, supported_formats: bool = True) -> None: ("freetype2", "FREETYPE2"), ("littlecms2", "LITTLECMS2"), ("webp", "WEBP"), + ("jpegxl", "JPEG XL"), ("jpg", "JPEG"), ("jpg_2000", "OPENJPEG (JPEG2000)"), ("zlib", "ZLIB (PNG/ZIP)"), diff --git a/src/_jpegxl.c b/src/_jpegxl.c new file mode 100644 index 00000000000..592db9adca6 --- /dev/null +++ b/src/_jpegxl.c @@ -0,0 +1,663 @@ +#define PY_SSIZE_T_CLEAN +#include +#include +#include "libImaging/Imaging.h" + +#include +#include +#include +#include + +#define _PIL_JXL_CHECK(call_name) \ + if (decp->status != JXL_DEC_SUCCESS) { \ + jxl_call_name = call_name; \ + goto end; \ + } + +void +_pil_jxl_get_pixel_format(JxlPixelFormat *pf, const JxlBasicInfo *bi) { + pf->num_channels = bi->num_color_channels + bi->num_extra_channels; + + if (bi->exponent_bits_per_sample > 0 || bi->alpha_exponent_bits > 0) { + pf->data_type = JXL_TYPE_FLOAT; // not yet supported + } else if (bi->bits_per_sample > 8) { + pf->data_type = JXL_TYPE_UINT16; // not yet supported + } else { + pf->data_type = JXL_TYPE_UINT8; + } + + // this *might* cause some issues on Big-Endian systems + // would be great to test it + pf->endianness = JXL_NATIVE_ENDIAN; + pf->align = 0; +} + +// TODO: floating point mode +char * +_pil_jxl_get_mode(const JxlBasicInfo *bi) { + // 16-bit single channel images are supported + if (bi->bits_per_sample == 16 && bi->num_color_channels == 1 && + bi->alpha_bits == 0 && !bi->alpha_premultiplied) + return "I;16"; + + // PIL doesn't support high bit depth images + // it will throw an exception but that's for your own good + // you wouldn't want to see distorted image + if (bi->bits_per_sample != 8) + return "uns"; + + // image has transparency + if (bi->alpha_bits > 0) { + if (bi->num_color_channels == 3) { + if (bi->alpha_premultiplied) + return "RGBa"; + return "RGBA"; + } + if (bi->num_color_channels == 1) { + if (bi->alpha_premultiplied) + return "La"; + return "LA"; + } + } + + // image has no transparency + if (bi->num_color_channels == 3) + return "RGB"; + if (bi->num_color_channels == 1) + return "L"; + + // could not recognize mode + return NULL; +} + +// Decoder type +typedef struct { + PyObject_HEAD JxlDecoder *decoder; + void *runner; + + uint8_t *jxl_data; // input jxl bitstream + Py_ssize_t jxl_data_len; // length of input jxl bitstream + + uint8_t *outbuf; + Py_ssize_t outbuf_len; + + uint8_t *jxl_icc; + Py_ssize_t jxl_icc_len; + uint8_t *jxl_exif; + Py_ssize_t jxl_exif_len; + uint8_t *jxl_xmp; + Py_ssize_t jxl_xmp_len; + + JxlDecoderStatus status; + JxlBasicInfo basic_info; + JxlPixelFormat pixel_format; + + Py_ssize_t n_frames; + + char *mode; +} PILJpegXlDecoderObject; + +static PyTypeObject PILJpegXlDecoder_Type; + +void +_jxl_decoder_dealloc(PyObject *self) { + PILJpegXlDecoderObject *decp = (PILJpegXlDecoderObject *)self; + + if (decp->jxl_data) { + free(decp->jxl_data); + decp->jxl_data = NULL; + decp->jxl_data_len = 0; + } + if (decp->outbuf) { + free(decp->outbuf); + decp->outbuf = NULL; + decp->outbuf_len = 0; + } + if (decp->jxl_icc) { + free(decp->jxl_icc); + decp->jxl_icc = NULL; + decp->jxl_icc_len = 0; + } + if (decp->jxl_exif) { + free(decp->jxl_exif); + decp->jxl_exif = NULL; + decp->jxl_exif_len = 0; + } + if (decp->jxl_xmp) { + free(decp->jxl_xmp); + decp->jxl_xmp = NULL; + decp->jxl_xmp_len = 0; + } + + if (decp->decoder) { + JxlDecoderDestroy(decp->decoder); + decp->decoder = NULL; + } + + if (decp->runner) { + JxlThreadParallelRunnerDestroy(decp->runner); + decp->runner = NULL; + } +} + +// sets input jxl bitstream loaded into jxl_data +// has to be called after every rewind +void +_jxl_decoder_set_input(PyObject *self) { + PILJpegXlDecoderObject *decp = (PILJpegXlDecoderObject *)self; + + decp->status = + JxlDecoderSetInput(decp->decoder, decp->jxl_data, decp->jxl_data_len); + + // the input contains the whole jxl bitstream so it can be closed + JxlDecoderCloseInput(decp->decoder); +} + +PyObject * +_jxl_decoder_rewind(PyObject *self) { + PILJpegXlDecoderObject *decp = (PILJpegXlDecoderObject *)self; + JxlDecoderRewind(decp->decoder); + Py_RETURN_NONE; +} + +bool +_jxl_decoder_count_frames(PyObject *self) { + PILJpegXlDecoderObject *decp = (PILJpegXlDecoderObject *)self; + + decp->n_frames = 0; + + // count all JXL_DEC_NEED_IMAGE_OUT_BUFFER events + while (decp->status != JXL_DEC_SUCCESS) { + // printf("fetch_frame_count status: %u\n", decp->status); + decp->status = JxlDecoderProcessInput(decp->decoder); + + if (decp->status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + if (JxlDecoderSkipCurrentFrame(decp->decoder) != JXL_DEC_SUCCESS) { + return false; + } + decp->n_frames++; + } + } + + _jxl_decoder_rewind((PyObject *)decp); + + return true; +} + +PyObject * +_jxl_decoder_new(PyObject *self, PyObject *args) { + PyBytesObject *jxl_string; + + PILJpegXlDecoderObject *decp = NULL; + decp = PyObject_New(PILJpegXlDecoderObject, &PILJpegXlDecoder_Type); + decp->mode = NULL; + decp->jxl_data = NULL; + decp->jxl_data_len = 0; + decp->outbuf = NULL; + decp->outbuf_len = 0; + decp->jxl_icc = NULL; + decp->jxl_icc_len = 0; + decp->jxl_exif = NULL; + decp->jxl_exif_len = 0; + decp->jxl_xmp = NULL; + decp->jxl_xmp_len = 0; + decp->n_frames = 0; + + // used for printing more detailed error messages + char *jxl_call_name; + + // parse one argument which is a string with jxl data + if (!PyArg_ParseTuple(args, "S", &jxl_string)) { + return NULL; + } + + // this data needs to be copied to PILJpegXlDecoderObject + // so that input bitstream is preserved across calls + const uint8_t *_tmp_jxl_data; + Py_ssize_t _tmp_jxl_data_len; + + // convert jxl data string to C uint8_t pointer + PyBytes_AsStringAndSize( + (PyObject *)jxl_string, (char **)&_tmp_jxl_data, &_tmp_jxl_data_len + ); + + // here occurs this copying (inefficiency) + decp->jxl_data = malloc(_tmp_jxl_data_len); + memcpy(decp->jxl_data, _tmp_jxl_data, _tmp_jxl_data_len); + decp->jxl_data_len = _tmp_jxl_data_len; + + // printf("%zu\n", decp->jxl_data_len); + + size_t suggested_num_threads = JxlThreadParallelRunnerDefaultNumWorkerThreads(); + decp->runner = JxlThreadParallelRunnerCreate(NULL, suggested_num_threads); + decp->decoder = JxlDecoderCreate(NULL); + + decp->status = JxlDecoderSetParallelRunner( + decp->decoder, JxlThreadParallelRunner, decp->runner + ); + _PIL_JXL_CHECK("JxlDecoderSetParallelRunner") + + decp->status = JxlDecoderSubscribeEvents( + decp->decoder, + JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING | JXL_DEC_FRAME | JXL_DEC_BOX | + JXL_DEC_FULL_IMAGE + ); + _PIL_JXL_CHECK("JxlDecoderSubscribeEvents") + + // tell libjxl to decompress boxes (for example Exif is usually compressed) + decp->status = JxlDecoderSetDecompressBoxes(decp->decoder, JXL_TRUE); + _PIL_JXL_CHECK("JxlDecoderSetDecompressBoxes") + + _jxl_decoder_set_input((PyObject *)decp); + _PIL_JXL_CHECK("JxlDecoderSetInput") + + // decode everything up to the first frame + do { + decp->status = JxlDecoderProcessInput(decp->decoder); + // printf("Status: %d\n", decp->status); + +decoder_loop_skip_process: + + // there was an error at JxlDecoderProcessInput stage + if (decp->status == JXL_DEC_ERROR) { + jxl_call_name = "JxlDecoderProcessInput"; + goto end; + } + + // got basic info + if (decp->status == JXL_DEC_BASIC_INFO) { + decp->status = JxlDecoderGetBasicInfo(decp->decoder, &decp->basic_info); + _PIL_JXL_CHECK("JxlDecoderGetBasicInfo"); + + _pil_jxl_get_pixel_format(&decp->pixel_format, &decp->basic_info); + if (decp->pixel_format.data_type != JXL_TYPE_UINT8 && + decp->pixel_format.data_type != JXL_TYPE_UINT16) { + // only 8 bit integer value images are supported for now + PyErr_SetString( + PyExc_NotImplementedError, "unsupported pixel data type" + ); + goto end_with_custom_error; + } + decp->mode = _pil_jxl_get_mode(&decp->basic_info); + + continue; + } + + // got color encoding + if (decp->status == JXL_DEC_COLOR_ENCODING) { + decp->status = JxlDecoderGetICCProfileSize( + decp->decoder, JXL_COLOR_PROFILE_TARGET_DATA, &decp->jxl_icc_len + ); + _PIL_JXL_CHECK("JxlDecoderGetICCProfileSize"); + + decp->jxl_icc = malloc(decp->jxl_icc_len); + if (!decp->jxl_icc) { + PyErr_SetString(PyExc_OSError, "jxl_icc malloc failed"); + goto end_with_custom_error; + } + + decp->status = JxlDecoderGetColorAsICCProfile( + decp->decoder, + JXL_COLOR_PROFILE_TARGET_DATA, + decp->jxl_icc, + decp->jxl_icc_len + ); + _PIL_JXL_CHECK("JxlDecoderGetColorAsICCProfile"); + + continue; + } + + if (decp->status == JXL_DEC_BOX) { + char btype[4]; + decp->status = JxlDecoderGetBoxType(decp->decoder, btype, JXL_TRUE); + _PIL_JXL_CHECK("JxlDecoderGetBoxType"); + + // printf("found box type: %c%c%c%c\n", btype[0], btype[1], btype[2], + // btype[3]); + + bool is_box_exif, is_box_xmp; + is_box_exif = !memcmp(btype, "Exif", 4); + is_box_xmp = !memcmp(btype, "xml ", 4); + if (!is_box_exif && !is_box_xmp) { + // not exif/xmp box so continue + continue; + } + + size_t cur_compr_box_size; + decp->status = JxlDecoderGetBoxSizeRaw(decp->decoder, &cur_compr_box_size); + _PIL_JXL_CHECK("JxlDecoderGetBoxSizeRaw"); + // printf("Exif/xmp box size: %zu\n", cur_compr_box_size); + + uint8_t *final_jxl_buf = NULL; + Py_ssize_t final_jxl_buf_len = 0; + + // cur_box_size is actually compressed box size + // it will also serve as our chunk size + do { + uint8_t *_new_jxl_buf = + realloc(final_jxl_buf, final_jxl_buf_len + cur_compr_box_size); + if (!_new_jxl_buf) { + PyErr_SetString(PyExc_OSError, "failed to allocate final_jxl_buf"); + goto end; + } + final_jxl_buf = _new_jxl_buf; + + decp->status = JxlDecoderSetBoxBuffer( + decp->decoder, final_jxl_buf + final_jxl_buf_len, cur_compr_box_size + ); + _PIL_JXL_CHECK("JxlDecoderSetBoxBuffer"); + + decp->status = JxlDecoderProcessInput(decp->decoder); + + size_t remaining = JxlDecoderReleaseBoxBuffer(decp->decoder); + // printf("boxes status: %d, remaining: %zu\n", decp->status, + // remaining); + final_jxl_buf_len += (cur_compr_box_size - remaining); + } while (decp->status == JXL_DEC_BOX_NEED_MORE_OUTPUT); + + if (is_box_exif) { + decp->jxl_exif = final_jxl_buf; + decp->jxl_exif_len = final_jxl_buf_len; + } else { + decp->jxl_xmp = final_jxl_buf; + decp->jxl_xmp_len = final_jxl_buf_len; + } + + // dirty hack: skip first step of decoding loop since + // we already did it in do...while above + goto decoder_loop_skip_process; + } + + } while (decp->status != JXL_DEC_FRAME); + + // couldn't determine Image mode or it is unsupported + if (!strcmp(decp->mode, "uns") || !decp->mode) { + PyErr_SetString(PyExc_NotImplementedError, "only 8-bit images are supported"); + goto end_with_custom_error; + } + + if (decp->basic_info.have_animation) { + // get frame count by iterating over image out events + if (!_jxl_decoder_count_frames((PyObject *)decp)) { + PyErr_SetString(PyExc_OSError, "something went wrong when counting frames"); + goto end_with_custom_error; + } + } + + return (PyObject *)decp; + // Py_RETURN_NONE; + + // on success we should never reach here + + // set error message + char err_msg[128]; + +end: + snprintf( + err_msg, + 128, + "could not create decoder object. libjxl call: %s returned: %d", + jxl_call_name, + decp->status + ); + PyErr_SetString(PyExc_OSError, err_msg); + +end_with_custom_error: + + // deallocate + _jxl_decoder_dealloc((PyObject *)decp); + PyObject_Del(decp); + + return NULL; +} + +PyObject * +_jxl_decoder_get_info(PyObject *self) { + PILJpegXlDecoderObject *decp = (PILJpegXlDecoderObject *)self; + + return Py_BuildValue( + "IIsiIIII", + decp->basic_info.xsize, + decp->basic_info.ysize, + decp->mode, + decp->basic_info.have_animation, + decp->basic_info.animation.tps_numerator, + decp->basic_info.animation.tps_denominator, + decp->basic_info.animation.num_loops, + decp->n_frames + ); +} + +PyObject * +_jxl_decoder_get_next(PyObject *self) { + PILJpegXlDecoderObject *decp = (PILJpegXlDecoderObject *)self; + PyObject *bytes; + PyObject *ret; + JxlFrameHeader fhdr = {}; + + char *jxl_call_name; + + // process events until next frame output is ready + while (decp->status != JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + decp->status = JxlDecoderProcessInput(decp->decoder); + + // every frame was decoded successfully + if (decp->status == JXL_DEC_SUCCESS) { + Py_RETURN_NONE; + } + + // this should only occur after rewind + if (decp->status == JXL_DEC_NEED_MORE_INPUT) { + _jxl_decoder_set_input((PyObject *)decp); + _PIL_JXL_CHECK("JxlDecoderSetInput") + continue; + } + + if (decp->status == JXL_DEC_FRAME) { + // decode frame header + decp->status = JxlDecoderGetFrameHeader(decp->decoder, &fhdr); + _PIL_JXL_CHECK("JxlDecoderGetFrameHeader"); + continue; + } + } + + size_t new_outbuf_len; + decp->status = JxlDecoderImageOutBufferSize( + decp->decoder, &decp->pixel_format, &new_outbuf_len + ); + _PIL_JXL_CHECK("JxlDecoderImageOutBufferSize"); + + // only allocate memory when current buffer is too small + if (decp->outbuf_len < new_outbuf_len) { + decp->outbuf_len = new_outbuf_len; + uint8_t *_new_outbuf = realloc(decp->outbuf, decp->outbuf_len); + if (!_new_outbuf) { + PyErr_SetString(PyExc_OSError, "failed to allocate outbuf"); + goto end_with_custom_error; + } + decp->outbuf = _new_outbuf; + } + + decp->status = JxlDecoderSetImageOutBuffer( + decp->decoder, &decp->pixel_format, decp->outbuf, decp->outbuf_len + ); + _PIL_JXL_CHECK("JxlDecoderSetImageOutBuffer"); + + // decode image into output_buffer + decp->status = JxlDecoderProcessInput(decp->decoder); + + if (decp->status != JXL_DEC_FULL_IMAGE) { + PyErr_SetString(PyExc_OSError, "failed to read next frame"); + goto end_with_custom_error; + } + + bytes = PyBytes_FromStringAndSize((char *)(decp->outbuf), decp->outbuf_len); + + ret = Py_BuildValue("SIi", bytes, fhdr.duration, fhdr.is_last); + + Py_DECREF(bytes); + return ret; + + // we also shouldn't reach here if frame read was ok + + // set error message + char err_msg[128]; + +end: + snprintf( + err_msg, + 128, + "could not read frame. libjxl call: %s returned: %d", + jxl_call_name, + decp->status + ); + PyErr_SetString(PyExc_OSError, err_msg); + +end_with_custom_error: + + // no need to deallocate anything here + // user can just ignore error + + return NULL; +} + +PyObject * +_jxl_decoder_get_icc(PyObject *self) { + PILJpegXlDecoderObject *decp = (PILJpegXlDecoderObject *)self; + + if (!decp->jxl_icc) + Py_RETURN_NONE; + + return PyBytes_FromStringAndSize((const char *)decp->jxl_icc, decp->jxl_icc_len); +} + +PyObject * +_jxl_decoder_get_exif(PyObject *self) { + PILJpegXlDecoderObject *decp = (PILJpegXlDecoderObject *)self; + + if (!decp->jxl_exif) + Py_RETURN_NONE; + + return PyBytes_FromStringAndSize((const char *)decp->jxl_exif, decp->jxl_exif_len); +} + +PyObject * +_jxl_decoder_get_xmp(PyObject *self) { + PILJpegXlDecoderObject *decp = (PILJpegXlDecoderObject *)self; + + if (!decp->jxl_xmp) + Py_RETURN_NONE; + + return PyBytes_FromStringAndSize((const char *)decp->jxl_xmp, decp->jxl_xmp_len); +} + +// PILJpegXlDecoder methods +static struct PyMethodDef _jpegxl_decoder_methods[] = { + {"get_info", (PyCFunction)_jxl_decoder_get_info, METH_NOARGS, "get_info"}, + {"get_next", (PyCFunction)_jxl_decoder_get_next, METH_NOARGS, "get_next"}, + {"get_icc", (PyCFunction)_jxl_decoder_get_icc, METH_NOARGS, "get_icc"}, + {"get_exif", (PyCFunction)_jxl_decoder_get_exif, METH_NOARGS, "get_exif"}, + {"get_xmp", (PyCFunction)_jxl_decoder_get_xmp, METH_NOARGS, "get_xmp"}, + {"rewind", (PyCFunction)_jxl_decoder_rewind, METH_NOARGS, "rewind"}, + {NULL, NULL} /* sentinel */ +}; + +// PILJpegXlDecoder type definition +static PyTypeObject PILJpegXlDecoder_Type = { + PyVarObject_HEAD_INIT(NULL, 0) "PILJpegXlDecoder", /*tp_name */ + sizeof(PILJpegXlDecoderObject), /*tp_basicsize */ + 0, /*tp_itemsize */ + /* methods */ + (destructor)_jxl_decoder_dealloc, /*tp_dealloc*/ + 0, /*tp_vectorcall_offset*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_as_async*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT, /*tp_flags*/ + 0, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + _jpegxl_decoder_methods, /*tp_methods*/ + 0, /*tp_members*/ + 0, /*tp_getset*/ +}; + +// Return libjxl decoder version available as integer: +// MAJ*1_000_000 + MIN*1_000 + PATCH +PyObject * +JpegXlDecoderVersion_wrapper() { + return Py_BuildValue("i", JxlDecoderVersion()); +} + +// Version as string +const char * +JpegXlDecoderVersion_str(void) { + static char version[20]; + int version_number = JxlDecoderVersion(); + sprintf( + version, + "%d.%d.%d", + version_number / 1000000, + (version_number % 1000000) / 1000, + (version_number % 1000) + ); + return version; +} + +static PyMethodDef jpegxlMethods[] = { + {"JpegXlDecoderVersion", JpegXlDecoderVersion_wrapper, METH_NOARGS, "JpegXlVersion" + }, + {"PILJpegXlDecoder", _jxl_decoder_new, METH_VARARGS, "PILJpegXlDecoder"}, + {NULL, NULL} +}; + +static int +setup_module(PyObject *m) { + if (PyType_Ready(&PILJpegXlDecoder_Type) < 0) { + return -1; + } + + // TODO(oloke) ready object types? + PyObject *d = PyModule_GetDict(m); + + PyObject *v = PyUnicode_FromString(JpegXlDecoderVersion_str()); + PyDict_SetItemString(d, "libjxl_version", v ? v : Py_None); + Py_XDECREF(v); + + return 0; +} + +PyMODINIT_FUNC +PyInit__jpegxl(void) { + PyObject *m; + + static PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_jpegxl", /* m_name */ + NULL, /* m_doc */ + -1, /* m_size */ + jpegxlMethods, /* m_methods */ + }; + + m = PyModule_Create(&module_def); + if (setup_module(m) < 0) { + Py_DECREF(m); + return NULL; + } + + return m; +}