From 645c2fd12a369c8e14fa2693de29b155302d3620 Mon Sep 17 00:00:00 2001 From: Michael Mauderer Date: Mon, 2 Dec 2024 12:33:04 +0000 Subject: [PATCH] Refactor to use LUT sequence operators for executing a CLF workflow. --- colour/io/luts/clf.py | 403 +++++++++++------- .../lut1_with_half_domain_sample.xml | 0 colour/io/luts/tests/clf_apply/test_lut1d.py | 2 +- colour/io/luts/tests/test_clf_common.py | 4 +- 4 files changed, 246 insertions(+), 163 deletions(-) rename colour/io/luts/tests/{resources/clf => clf_apply/resources}/lut1_with_half_domain_sample.xml (100%) diff --git a/colour/io/luts/clf.py b/colour/io/luts/clf.py index c3945403d..bfcd3c1bf 100644 --- a/colour/io/luts/clf.py +++ b/colour/io/luts/clf.py @@ -3,16 +3,20 @@ """ import colour_clf_io as clf import numpy as np +from numpy.typing import ArrayLike, NDArray +from colour import LUTSequence from colour.algebra import ( table_interpolation_tetrahedral, table_interpolation_trilinear, ) from colour.hints import ( + Any, NDArrayFloat, + ProtocolLUTSequenceItem, ) -from colour.io.luts import LUT1D, LUT3D -from colour.utilities import tsplit, tstack +from colour.io import AbstractLUTSequenceOperator, luts +from colour.utilities import as_float_array, tsplit, tstack __all__ = ["apply"] @@ -42,7 +46,7 @@ def from_f16_to_uint16(array: npt.NDArray[np.float16]) -> npt.NDArray[np.uint16] return array # type: ignore -def apply_by_channel(value, f, params, extra_args=None): +def apply_by_channel(value, f, params, extra_args=None) -> NDArray: if params is None or len(params) == 0: return f(value, params, extra_args) elif len(params) == 1 and params[0].channel is None: @@ -69,43 +73,90 @@ def get_interpolator_for_LUT3D(node: clf.LUT3D): raise NotImplementedError -def apply_LUT3D(node: clf.LUT3D, normalised_value: NDArrayFloat) -> NDArrayFloat: - table = node.array.as_array() - size = node.array.dim[0] - if node.raw_halfs: - table = from_uint16_to_f16(table) - if node.half_domain: - normalised_value = np.array(normalised_value, dtype=np.float16) - normalised_value = from_f16_to_uint16(normalised_value) / (size - 1) - # We need to map to indices, where 1 indicates the last element in the LUT array. - value_scaled = normalised_value * (size - 1) - extrapolator_kwargs = {"method": "Constant"} - interpolator = get_interpolator_for_LUT3D(node) - lut = LUT3D(table, size=size) - return lut.apply( - value_scaled, extrapolator_kwargs=extrapolator_kwargs, interpolator=interpolator - ) - - -def apply_LUT1D(node: clf.LUT1D, normalised_value: NDArrayFloat) -> NDArrayFloat: - table = node.array.as_array() - size = node.array.dim[0] - if node.raw_halfs: - table = from_uint16_to_f16(table) - if node.half_domain: - normalised_value = np.array(normalised_value, dtype=np.float16) - normalised_value = from_f16_to_uint16(normalised_value) / (size - 1) - domain = np.min(table), np.max(table) - # We need to map to indices, where 1 indicates the last element in the LUT array. - value_scaled = normalised_value * (size - 1) - lut = LUT1D(table, size=size, domain=domain) - extrapolator_kwargs = {"method": "Constant"} - return lut.apply(value_scaled, extrapolator_kwargs=extrapolator_kwargs) - - -def apply_matrix(node: clf.Matrix, value: NDArrayFloat) -> NDArrayFloat: - matrix = node.array.as_array() - return matrix.dot(value) +class CLFNode(AbstractLUTSequenceOperator): + node: clf.ProcessNode + + def __init__(self, node: clf.ProcessNode): + super().__init__(node.name, [node.description]) + self.node = node + + def from_input_range(self, value): + return value + + def to_output_range(self, value): + return value / self.node.out_bit_depth.scale_factor() + + +class LUT3D(CLFNode): + node: clf.LUT3D + + def __init__(self, node: clf.LUT3D): + super().__init__(node) + self.node = node + + def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002 + RGB = self.from_input_range(RGB) + node = self.node + table = node.array.as_array() + size = node.array.dim[0] + if node.raw_halfs: + table = from_uint16_to_f16(table) + if node.half_domain: + RGB = np.array(RGB, dtype=np.float16) + RGB = from_f16_to_uint16(RGB) / (size - 1) + # We need to map to indices, where 1 indicates the last element in the + # LUT array. + value_scaled = RGB * (size - 1) + extrapolator_kwargs = {"method": "Constant"} + interpolator = get_interpolator_for_LUT3D(node) + lut = luts.LUT3D(table, size=size) + out = lut.apply( + value_scaled, + extrapolator_kwargs=extrapolator_kwargs, + interpolator=interpolator, + ) + out = self.to_output_range(out) + return out + + +class LUT1D(CLFNode): + node: clf.LUT1D + + def __init__(self, node: clf.LUT1D): + super().__init__(node) + self.node = node + + def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002 + RGB = self.from_input_range(RGB) + table = self.node.array.as_array() + size = self.node.array.dim[0] + if self.node.raw_halfs: + table = from_uint16_to_f16(table) + if self.node.half_domain: + RGB = np.array(RGB, dtype=np.float16) + RGB = from_f16_to_uint16(RGB) / (size - 1) + domain = np.min(table), np.max(table) + # We need to map to indices, where 1 indicates the last element in the + # LUT array. + value_scaled = RGB * (size - 1) + lut = luts.LUT1D(table, size=size, domain=domain) + extrapolator_kwargs = {"method": "Constant"} + out = lut.apply(value_scaled, extrapolator_kwargs=extrapolator_kwargs) + out = self.to_output_range(out) + return out + + +class Matrix(CLFNode): + node: clf.Matrix + + def __init__(self, node: clf.Matrix): + super().__init__(node) + self.node = node + + def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002 + RGB = self.from_input_range(RGB) + matrix = self.node.array.as_array() + return matrix.dot(RGB) def assert_range_correct(in_out, bit_depth_scale): @@ -119,36 +170,46 @@ def assert_range_correct(in_out, bit_depth_scale): ) -def apply_range(node: clf.Range, normalised_value: NDArrayFloat): - value = normalised_value * node.in_bit_depth.scale_factor() - max_in = node.max_in_value - max_out = node.max_out_value - max_in_out = node.max_in_value, node.max_out_value - min_in = node.min_in_value - min_out = node.min_out_value - min_in_out = node.min_in_value, node.min_out_value - do_clamping = node.style is None or node.style == node.style.CLAMP - - if None in max_in_out or None in min_in_out: - if not do_clamping: - raise ValueError( - "Inconsistent settings in range node. " - "Clamping was not set, but not all values to calculate a " - "range are supplied. " +class Range(CLFNode): + node: clf.LUT1D + + def __init__(self, node: clf.LUT1D): + super().__init__(node) + self.node = node + + def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002 + node = self.node + value = RGB * self.node.in_bit_depth.scale_factor() + max_in = node.max_in_value + max_out = node.max_out_value + max_in_out = node.max_in_value, node.max_out_value + min_in = node.min_in_value + min_out = node.min_out_value + min_in_out = node.min_in_value, node.min_out_value + do_clamping = node.style is None or node.style == node.style.CLAMP + + if None in max_in_out or None in min_in_out: + if not do_clamping: + raise ValueError( + "Inconsistent settings in range node. " + "Clamping was not set, but not all values to calculate a " + "range are supplied. " + ) + bit_depth_scale = ( + node.out_bit_depth.scale_factor() / node.in_bit_depth.scale_factor() ) - bit_depth_scale = ( - node.out_bit_depth.scale_factor() / node.in_bit_depth.scale_factor() - ) - assert_range_correct(min_in_out, bit_depth_scale) - assert_range_correct(max_in_out, bit_depth_scale) - scaled_value = value * bit_depth_scale - return np.clip(scaled_value, min_out, max_out) - else: - scale = (max_out - min_out) / (max_in - min_in) - result = value * scale + min_out - min_in * scale - if do_clamping: - result = np.clip(result, min_out, max_out) - return result + assert_range_correct(min_in_out, bit_depth_scale) + assert_range_correct(max_in_out, bit_depth_scale) + scaled_value = value * bit_depth_scale + out = np.clip(scaled_value, min_out, max_out) + else: + scale = (max_out - min_out) / (max_in - min_in) + result = value * scale + min_out - min_in * scale + if do_clamping: + result = np.clip(result, min_out, max_out) + out = result + out = self.to_output_range(out) + return out FLT_MIN = 1.175494e-38 @@ -240,16 +301,27 @@ def apply_log_internal(value: NDArrayFloat, params, extra_args) -> NDArrayFloat: raise ValueError(f"Invalid Log Style: {style}") -def apply_log(node: clf.Log, normalised_value: NDArrayFloat) -> NDArrayFloat: - style = node.style - params = node.log_params - extra_args = style, node.in_bit_depth, node.out_bit_depth - return apply_by_channel( - normalised_value, - apply_log_internal, - params, - extra_args, - ) +class Log(CLFNode): + node: clf.Log + + def __init__(self, node: clf.Log): + super().__init__(node) + self.node = node + + def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002 + RGB = self.from_input_range(RGB) + node = self.node + style = node.style + params = node.log_params + extra_args = style, node.in_bit_depth, node.out_bit_depth + out = apply_by_channel( + RGB, + apply_log_internal, + params, + extra_args, + ) + out = self.to_output_range(out) + return out def mon_curve_forward(x, exponent, offset): @@ -307,12 +379,21 @@ def apply_exponent_internal( raise ValueError(f"Invalid Exponent Style: {style}") -def apply_exponent(node: clf.Exponent, normalised_value: NDArrayFloat) -> NDArrayFloat: - style = node.style - params = node.exponent_params - return apply_by_channel( - normalised_value, apply_exponent_internal, params, extra_args=style - ) +class Exponent(CLFNode): + node: clf.Exponent + + def __init__(self, node: clf.Exponent): + super().__init__(node) + self.node = node + + def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002 + node = self.node + RGB = self.from_input_range(RGB) + style = node.style + params = node.exponent_params + out = apply_by_channel(RGB, apply_exponent_internal, params, extra_args=style) + out = self.to_output_range(out) + return out def asc_cdl_luma(value): @@ -323,96 +404,98 @@ def asc_cdl_luma(value): return luma -def apply_asc_cdl(node: clf.ASC_CDL, normalised_value: NDArrayFloat): - sop = node.sopnode - if sop is None: - slope = np.array([1.0, 1.0, 1.0]) - offset = np.array([0.0, 0.0, 0.0]) - power = np.array([1.0, 1.0, 1.0]) - else: - slope = np.array(sop.slope) - offset = np.array(sop.offset) - power = np.array(sop.power) - saturation = 1.0 if node.sat_node is None else node.sat_node.saturation - - def clamp(x): - return np.clip(x, 0.0, 1.0) - - match node.style: - case clf.ASC_CDL_Style.FWD: - out_sop = ( - clamp( - normalised_value * slope + offset, +class ASC_CDL(CLFNode): + node: clf.ASC_CDL + + def __init__(self, node: clf.ASC_CDL): + super().__init__(node) + self.node = node + + def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002 + node = self.node + RGB = self.from_input_range(RGB) + sop = node.sopnode + if sop is None: + slope = np.array([1.0, 1.0, 1.0]) + offset = np.array([0.0, 0.0, 0.0]) + power = np.array([1.0, 1.0, 1.0]) + else: + slope = np.array(sop.slope) + offset = np.array(sop.offset) + power = np.array(sop.power) + saturation = 1.0 if node.sat_node is None else node.sat_node.saturation + + def clamp(x): + return np.clip(x, 0.0, 1.0) + + match node.style: + case clf.ASC_CDL_Style.FWD: + out_sop = ( + clamp( + RGB * slope + offset, + ) + ** power ) - ** power - ) - R, G, B = tsplit(out_sop) - luma = asc_cdl_luma(out_sop) - return clamp(luma + saturation * (out_sop - luma)) - case clf.ASC_CDL_Style.FWD_NO_CLAMP: - lin = normalised_value * slope + offset - out_sop = np.where(lin >= 0, lin**power, lin) - luma = asc_cdl_luma(out_sop) - return luma + saturation * (out_sop - luma) - case clf.ASC_CDL_Style.REV: - in_clamp = clamp(normalised_value) - luma = asc_cdl_luma(in_clamp) - out_sat = luma + (in_clamp - luma) / saturation - return clamp((clamp(out_sat) ** (1.0 / power) - offset) / slope) - case clf.ASC_CDL_Style.REV_NO_CLAMP: - luma = asc_cdl_luma(normalised_value) - out_sat = luma + (normalised_value - luma) / saturation - out_pw = np.where(out_sat >= 0, (out_sat) ** (1 / power), out_sat) - return (out_pw - offset) / slope - case _: - raise ValueError(f"Invalid ASC_CDL Style: {node.style}") - - -def apply_proces_node( - node: clf.ProcessNode, normalised_value: NDArrayFloat -) -> NDArrayFloat: + R, G, B = tsplit(out_sop) + luma = asc_cdl_luma(out_sop) + out = clamp(luma + saturation * (out_sop - luma)) + case clf.ASC_CDL_Style.FWD_NO_CLAMP: + lin = as_float_array(RGB * slope + offset) + out_sop = np.where(lin >= 0, lin**power, lin) + luma = asc_cdl_luma(out_sop) + out = luma + saturation * (out_sop - luma) + case clf.ASC_CDL_Style.REV: + in_clamp = clamp(RGB) + luma = asc_cdl_luma(in_clamp) + out_sat = luma + (in_clamp - luma) / saturation + out = clamp((clamp(out_sat) ** (1.0 / power) - offset) / slope) + case clf.ASC_CDL_Style.REV_NO_CLAMP: + luma = asc_cdl_luma(RGB) + out_sat = luma + (RGB - luma) / saturation + out_pw = np.where(out_sat >= 0, (out_sat) ** (1 / power), out_sat) + out = (out_pw - offset) / slope + case _: + raise ValueError(f"Invalid ASC_CDL Style: {node.style}") + out = self.to_output_range(out) + return out + + +def as_LUT_sequence_item(node: clf.ProcessNode) -> ProtocolLUTSequenceItem: if isinstance(node, clf.LUT1D): - return apply_LUT1D(node, normalised_value) + return LUT1D(node) if isinstance(node, clf.LUT3D): - return apply_LUT3D(node, normalised_value) + return LUT3D(node) if isinstance(node, clf.Matrix): - return apply_matrix(node, normalised_value) + return Matrix(node) if isinstance(node, clf.Range): - return apply_range(node, normalised_value) + return Range(node) if isinstance(node, clf.Log): - return apply_log(node, normalised_value) + return Log(node) if isinstance(node, clf.Exponent): - return apply_exponent(node, normalised_value) + return Exponent(node) if isinstance(node, clf.ASC_CDL): - return apply_asc_cdl(node, normalised_value) - - raise RuntimeError("No matching process node found") # TODO: Better error handling - - -def apply_next_node( - process_list: clf.ProcessList, - value: NDArrayFloat, - use_normalised_values: bool, -) -> NDArrayFloat: - next_node = process_list.process_nodes.pop(0) - if not use_normalised_values: - value = value / next_node.in_bit_depth.scale_factor() - result = apply_proces_node(next_node, value) - if use_normalised_values: - result = result / next_node.out_bit_depth.scale_factor() - return result + return ASC_CDL(node) + raise RuntimeError(f"No matching process node found for {node}.") def apply( process_list: clf.ProcessList, value: NDArrayFloat, - use_normalised_values=False, + normalised_values=False, ) -> NDArrayFloat: """Apply the transformation described by the given ProcessList to the given value. """ - result = value - while process_list.process_nodes: - result = apply_next_node(process_list, result, use_normalised_values) - use_normalised_values = False + if not normalised_values: + value = value / process_list.process_nodes[0].in_bit_depth.scale_factor() + + lut_sequence_items = [ + as_LUT_sequence_item(node) for node in process_list.process_nodes + ] + sequence = LUTSequence(*lut_sequence_items) + result = sequence.apply(value) + + if not normalised_values: + result = result * process_list.process_nodes[-1].out_bit_depth.scale_factor() + return result diff --git a/colour/io/luts/tests/resources/clf/lut1_with_half_domain_sample.xml b/colour/io/luts/tests/clf_apply/resources/lut1_with_half_domain_sample.xml similarity index 100% rename from colour/io/luts/tests/resources/clf/lut1_with_half_domain_sample.xml rename to colour/io/luts/tests/clf_apply/resources/lut1_with_half_domain_sample.xml diff --git a/colour/io/luts/tests/clf_apply/test_lut1d.py b/colour/io/luts/tests/clf_apply/test_lut1d.py index f7165ea52..825e546fa 100644 --- a/colour/io/luts/tests/clf_apply/test_lut1d.py +++ b/colour/io/luts/tests/clf_apply/test_lut1d.py @@ -162,7 +162,7 @@ def test_ocio_consistency_half_domain(self): consistent with `ociochecklut`. """ value_rgb = np.array([1.0, 0.5, 0.0]) - path = os.path.abspath("./resources/clf/lut1_with_half_domain_sample.xml") + path = os.path.abspath("./resources/lut1_with_half_domain_sample.xml") assert_ocio_consistency_for_file(value_rgb, path) diff --git a/colour/io/luts/tests/test_clf_common.py b/colour/io/luts/tests/test_clf_common.py index cea601716..99cc2dcb9 100644 --- a/colour/io/luts/tests/test_clf_common.py +++ b/colour/io/luts/tests/test_clf_common.py @@ -85,7 +85,7 @@ def assert_ocio_consistency(value, snippet: str, err_msg=""): tool for the given input. """ process_list = snippet_to_process_list(snippet) - process_list_output = apply(process_list, value, use_normalised_values=True) + process_list_output = apply(process_list, value, normalised_values=True) value_tuple = value[0], value[1], value[2] ocio_output = ocio_output_for_snippet(snippet, value_tuple) np.testing.assert_array_almost_equal( @@ -100,7 +100,7 @@ def assert_ocio_consistency_for_file(value_rgb, clf_path): from colour_clf_io import read_clf clf_data = read_clf(clf_path) - process_list_output = apply(clf_data, value_rgb, use_normalised_values=True) + process_list_output = apply(clf_data, value_rgb, normalised_values=True) ocio_output = result_as_array(ocio_outout_for_file(clf_path, value_rgb)) np.testing.assert_array_almost_equal(process_list_output, ocio_output)