Skip to content

Commit

Permalink
convert jxl error to python exceptions (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
Piezoid authored Oct 7, 2024
1 parent b228d51 commit fed1562
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 27 deletions.
22 changes: 14 additions & 8 deletions src/decode.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::borrow::Cow;

use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;

use jpegxl_rs::decode::{Data, Metadata, Pixels};
use jpegxl_rs::decoder_builder;
use jpegxl_rs::parallel::threads_runner::ThreadsRunner;
use jpegxl_rs::{decoder_builder, DecodeError};

// it works even if the item is not documented:

Expand Down Expand Up @@ -86,7 +87,7 @@ impl Decoder {
&self,
_py: Python,
data: &[u8],
) -> (bool, ImageInfo, Cow<'_, [u8]>, Cow<'_, [u8]>) {
) -> PyResult<(bool, ImageInfo, Cow<'_, [u8]>, Cow<'_, [u8]>)> {
_py.allow_threads(|| self.call_inner(data))
}

Expand All @@ -96,21 +97,22 @@ impl Decoder {
}

impl Decoder {
fn call_inner(&self, data: &[u8]) -> (bool, ImageInfo, Cow<'_, [u8]>, Cow<'_, [u8]>) {
fn call_inner(&self, data: &[u8]) -> PyResult<(bool, ImageInfo, Cow<'_, [u8]>, Cow<'_, [u8]>)> {
let parallel_runner = ThreadsRunner::new(
None,
if self.num_threads < 0 {
None
} else {
Some(self.num_threads as usize)
},
).unwrap();
)
.ok_or_else(|| PyRuntimeError::new_err("Could not create JxlThreadsRunner"))?;
let decoder = decoder_builder()
.icc_profile(true)
.parallel_runner(&parallel_runner)
.build()
.unwrap();
let (info, img) = decoder.reconstruct(&data).unwrap();
.map_err(to_pyjxlerror)?;
let (info, img) = decoder.reconstruct(&data).map_err(to_pyjxlerror)?;
let (jpeg, img) = match img {
Data::Jpeg(x) => (true, x),
Data::Pixels(x) => (false, convert_pixels(x)),
Expand All @@ -119,11 +121,15 @@ impl Decoder {
Some(x) => x.to_vec(),
None => Vec::new(),
};
(
Ok((
jpeg,
ImageInfo::from(info),
Cow::Owned(img),
Cow::Owned(icc_profile),
)
))
}
}

fn to_pyjxlerror(e: DecodeError) -> PyErr {
PyRuntimeError::new_err(e.to_string())
}
52 changes: 34 additions & 18 deletions src/encode.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::borrow::Cow;

use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;

use jpegxl_rs::encode::{ColorEncoding, EncoderFrame, EncoderResult, EncoderSpeed, Metadata};
use jpegxl_rs::encoder_builder;
use jpegxl_rs::parallel::threads_runner::ThreadsRunner;
use jpegxl_rs::{encoder_builder, EncodeError};

#[pyclass(module = "pillow_jxl")]
pub struct Encoder {
Expand Down Expand Up @@ -32,26 +33,34 @@ impl Encoder {
use_container: bool,
use_original_profile: bool,
num_threads: isize,
) -> Self {
) -> PyResult<Self> {
let (num_channels, has_alpha) = match mode {
"RGBA" => (4, true),
"RGB" => (3, false),
"LA" => (2, true),
"L" => (1, false),
_ => panic!("Only RGB, RGBA, L, LA are supported."),
_ => {
return Err(PyValueError::new_err(
"Only RGB, RGBA, L, LA are supported.",
))
}
};

let decoding_speed = match decoding_speed {
0..=4 => decoding_speed,
_ => panic!("Decoding speed must be between 0 and 4"),
_ => {
return Err(PyValueError::new_err(
"Decoding speed must be between 0 and 4",
))
}
};

let use_original_profile = match lossless {
true => true,
false => use_original_profile,
};

Self {
Ok(Self {
num_channels,
has_alpha,
lossless,
Expand All @@ -61,7 +70,7 @@ impl Encoder {
use_container,
use_original_profile,
num_threads,
}
})
}

#[pyo3(signature = (data, width, height, jpeg_encode, exif=None, jumb=None, xmp=None))]
Expand All @@ -75,7 +84,7 @@ impl Encoder {
exif: Option<&[u8]>,
jumb: Option<&[u8]>,
xmp: Option<&[u8]>,
) -> Cow<'_, [u8]> {
) -> PyResult<Cow<'_, [u8]>> {
py.allow_threads(|| self.call_inner(data, width, height, jpeg_encode, exif, jumb, xmp))
}

Expand All @@ -97,15 +106,16 @@ impl Encoder {
exif: Option<&[u8]>,
jumb: Option<&[u8]>,
xmp: Option<&[u8]>,
) -> Cow<'_, [u8]> {
) -> PyResult<Cow<'_, [u8]>> {
let parallel_runner = ThreadsRunner::new(
None,
if self.num_threads < 0 {
None
} else {
Some(self.num_threads as usize)
},
).unwrap();
)
.ok_or_else(|| PyRuntimeError::new_err("Could not create JxlThreadsRunner"))?;
let mut encoder = encoder_builder()
.parallel_runner(&parallel_runner)
.jpeg_quality(self.quality)
Expand All @@ -114,12 +124,12 @@ impl Encoder {
.use_container(self.use_container)
.decoding_speed(self.decoding_speed)
.build()
.unwrap();
.map_err(to_pyjxlerror)?;
encoder.uses_original_profile = self.use_original_profile;
encoder.color_encoding = match self.num_channels {
1 | 2 => ColorEncoding::SrgbLuma,
3 | 4 => ColorEncoding::Srgb,
_ => panic!("Invalid num channels"),
_ => return Err(PyValueError::new_err("Invalid num channels")),
};
encoder.speed = match self.effort {
1 => EncoderSpeed::Lightning,
Expand All @@ -131,30 +141,36 @@ impl Encoder {
7 => EncoderSpeed::Squirrel,
8 => EncoderSpeed::Kitten,
9 => EncoderSpeed::Tortoise,
_ => panic!("Invalid effort"),
_ => return Err(PyValueError::new_err("Invalid effort")),
};
let buffer: EncoderResult<u8> = match jpeg_encode {
true => encoder.encode_jpeg(&data).unwrap(),
true => encoder.encode_jpeg(&data).map_err(to_pyjxlerror)?,
false => {
let frame = EncoderFrame::new(data).num_channels(self.num_channels);
if let Some(exif_data) = exif {
encoder
.add_metadata(&Metadata::Exif(exif_data), true)
.unwrap();
.map_err(to_pyjxlerror)?
}
if let Some(xmp_data) = xmp {
encoder
.add_metadata(&Metadata::Xmp(xmp_data), true)
.unwrap();
.map_err(to_pyjxlerror)?
}
if let Some(jumb_data) = jumb {
encoder
.add_metadata(&Metadata::Jumb(jumb_data), true)
.unwrap();
.map_err(to_pyjxlerror)?
}
encoder.encode_frame(&frame, width, height).unwrap()
encoder
.encode_frame(&frame, width, height)
.map_err(to_pyjxlerror)?
}
};
Cow::Owned(buffer.data)
Ok(Cow::Owned(buffer.data))
}
}

fn to_pyjxlerror(e: EncodeError) -> PyErr {
PyRuntimeError::new_err(e.to_string())
}
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use pyo3::prelude::*;
use pyo3::{create_exception, exceptions::PyRuntimeError, prelude::*};

// it works even if the item is not documented:
mod decode;
mod encode;

create_exception!(my_module, JxlException, PyRuntimeError, "Jxl Error");

#[pymodule]
#[pyo3(name = "pillow_jxl")]
fn pillow_jxl(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<decode::Decoder>()?;
m.add_class::<encode::Encoder>()?;
m.add("JxlException", m.py().get_type_bound::<JxlException>())?;
Ok(())
}

0 comments on commit fed1562

Please sign in to comment.