diff --git a/.gitignore b/.gitignore index 2568b8c..a7258d0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,11 @@ /target Cargo.lock +# for test +*.ipynb +*.jxl +*.jpeg + # Byte-compiled / optimized / DLL files __pycache__/ .pytest_cache/ diff --git a/pillow_jxl/JpegXLImagePlugin.py b/pillow_jxl/JpegXLImagePlugin.py index fdac160..d0586cc 100644 --- a/pillow_jxl/JpegXLImagePlugin.py +++ b/pillow_jxl/JpegXLImagePlugin.py @@ -27,7 +27,11 @@ def _open(self): self.fc = self.fp.read() self._decoder = Decoder() - self._jxlinfo, self._data = self._decoder(self.fc) + self.jpeg, self._jxlinfo, self._data = self._decoder(self.fc) + # FIXME (Isotr0py): Maybe slow down jpeg reconstruction + if self.jpeg: + with Image.open(BytesIO(self._data)) as im: + self._data = im.tobytes() self._size = (self._jxlinfo.width, self._jxlinfo.height) self.rawmode = self._jxlinfo.mode # NOTE (Isotr0py): PIL 10.1.0 changed the mode to property, use _mode instead @@ -85,7 +89,12 @@ def _save(im, fp, filename, save_all=False): decoding_speed=decoding_speed, use_container=use_container, ) - data = enc(im.tobytes(), im.width, im.height) + # FIXME (Isotr0py): im.filename maybe None if parse stream + if im.format == "JPEG" and im.filename: + with open(im.filename, "rb") as f: + data = enc(f.read(), im.width, im.height, jpeg_encode=True) + else: + data = enc(im.tobytes(), im.width, im.height, jpeg_encode=False) fp.write(data) diff --git a/pillow_jxl/__init__.pyi b/pillow_jxl/__init__.pyi index 6e94cc7..055561d 100644 --- a/pillow_jxl/__init__.pyi +++ b/pillow_jxl/__init__.pyi @@ -17,7 +17,16 @@ class Encoder: lossless: bool = True, quality: float = 0.0): ... - def __call__(self, data: bytes, width: int, height: int) -> bytes: ... + def __call__(self, data: bytes, width: int, height: int, jpeg_encode: bool) -> bytes: ... + ''' + Encode a jpeg-xl image. + + Args: + data(`bytes`): raw image bytes + + Return: + `bytes`: The encoded jpeg-xl image. + ''' class Decoder: @@ -30,7 +39,7 @@ class Decoder: def __init__(self, parallel: bool = True): ... - def __call__(self, data: bytes) -> (ImageInfo, bytes): ... + def __call__(self, data: bytes) -> (bool, ImageInfo, bytes): ... ''' Decode a jpeg-xl image. @@ -38,6 +47,7 @@ class Decoder: data(`bytes`): jpeg-xl image Return: + `bool`: If the jpeg is reconstructed `ImageInfo`: The metadata of decoded image `bytes`: The decoded image. ''' diff --git a/src/lib.rs b/src/lib.rs index 5beb7f2..9bbd971 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ use pyo3::prelude::*; use pyo3::types::PyBytes; -use jpegxl_rs::decode::{Metadata, Pixels}; +use jpegxl_rs::decode::{Metadata, Pixels, Data}; use jpegxl_rs::encode::EncoderResult; use jpegxl_rs::parallel::threads_runner::ThreadsRunner; use jpegxl_rs::{decoder_builder, encoder_builder}; @@ -52,8 +52,8 @@ impl Encoder { } } - #[pyo3(signature = (data, width, height))] - fn __call__<'a>(&'a self, _py: Python<'a>, data: &[u8], width: u32, height: u32) -> &PyBytes { + #[pyo3(signature = (data, width, height, jpeg_encode))] + fn __call__<'a>(&'a self, _py: Python<'a>, data: &[u8], width: u32, height: u32, jpeg_encode: bool) -> &PyBytes { let parallel_runner: ThreadsRunner; let mut encoder = match self.parallel { true => { @@ -71,7 +71,10 @@ impl Encoder { encoder.quality = self.quality; encoder.use_container = self.use_container; encoder.decoding_speed = self.decoding_speed; - let buffer: EncoderResult = encoder.encode(&data, width, height).unwrap(); + let buffer: EncoderResult = match jpeg_encode { + true => encoder.encode_jpeg(&data).unwrap(), + false => encoder.encode(&data, width, height).unwrap(), + }; PyBytes::new(_py, &buffer.data) } @@ -133,7 +136,7 @@ impl Decoder { } #[pyo3(signature = (data))] - fn __call__<'a>(&'a self, _py: Python<'a>, data: &[u8]) -> (ImageInfo, &PyBytes) { + fn __call__<'a>(&'a self, _py: Python<'a>, data: &[u8]) -> (bool, ImageInfo, &PyBytes) { let parallel_runner: ThreadsRunner; let decoder = match self.parallel { true => { @@ -145,12 +148,13 @@ impl Decoder { } false => decoder_builder().build().unwrap(), }; - let (info, img) = decoder.decode(&data).unwrap(); - let img: Vec = match img { - Pixels::Uint8(x) => x, + let (info, img) = decoder.reconstruct(&data).unwrap(); + let (jpeg, img) = match img { + Data::Jpeg(x) => (true, x), + Data::Pixels(Pixels::Uint8(x)) => (false, x), _ => panic!("Unsupported dtype for decoding"), }; - (ImageInfo::from(info), PyBytes::new(_py, &img)) + (jpeg, ImageInfo::from(info), PyBytes::new(_py, &img)) } fn __repr__(&self) -> PyResult {