diff --git a/keras/backend/jax/core.py b/keras/backend/jax/core.py
index 7dfd103ebd8..983caf57fcd 100644
--- a/keras/backend/jax/core.py
+++ b/keras/backend/jax/core.py
@@ -347,6 +347,10 @@ def unstack(x, num=None, axis=0):
     ]
 
 
+def custom_gradient(fun):
+    return jax.custom_gradient(fun=fun)
+
+
 def device_scope(device_name):
     if isinstance(device_name, str):
         # We support string value like "cpu:0", "gpu:1", etc.
diff --git a/keras/backend/numpy/core.py b/keras/backend/numpy/core.py
index 38ef5bae6ec..893e0497f6e 100644
--- a/keras/backend/numpy/core.py
+++ b/keras/backend/numpy/core.py
@@ -229,3 +229,9 @@ def stop_gradient(x):
 def unstack(x, num=None, axis=0):
     x = np.moveaxis(x, axis, 0)
     return [x[i] for i in range(x.shape[0])]
+
+
+def custom_gradient(fun):
+    raise NotImplementedError(
+        "`custom_gradient` is not supported with numpy backend"
+    )
diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py
index df0bf438e4e..d8edef8cb43 100644
--- a/keras/backend/tensorflow/core.py
+++ b/keras/backend/tensorflow/core.py
@@ -256,6 +256,10 @@ def unstack(x, num=None, axis=0):
     return tf.unstack(x, num=num, axis=axis)
 
 
+def custom_gradient(fun):
+    return tf.custom_gradient(f=fun)
+
+
 class name_scope(base_name_scope):
     def __init__(self, name, **kwargs):
         super().__init__(name, **kwargs)
diff --git a/keras/backend/torch/core.py b/keras/backend/torch/core.py
index 8fac0b97fad..3ba44057b93 100644
--- a/keras/backend/torch/core.py
+++ b/keras/backend/torch/core.py
@@ -435,3 +435,10 @@ def stop_gradient(variable):
 
 def unstack(x, num=None, axis=0):
     return x.unbind(axis)
+
+
+def custom_gradient(fun):
+    # TODO: Support this function
+    raise NotImplementedError(
+        "`custom_gradient` is not supported with torch backend"
+    )
diff --git a/keras/ops/core.py b/keras/ops/core.py
index 775e6299ea8..53d7cfc893a 100644
--- a/keras/ops/core.py
+++ b/keras/ops/core.py
@@ -11,6 +11,7 @@
 convert_to_numpy
 cond
 is_tensor
+custom_gradient
 """
 
 import numpy as np
@@ -623,3 +624,47 @@ def is_tensor(x):
         `True` if `x` is a tensor, otherwise `False`.
     """
     return backend.core.is_tensor(x)
+
+
+@keras_export("keras.ops.custom_gradient")
+def custom_gradient(f):
+    """Decorator to define a function with a custom gradient.
+
+    This decorator allows fine-grained control over the gradients of a
+    sequence of operations. This may be useful for multiple reasons,
+    including providing a more efficient or numerically stable gradient
+    for a sequence of operations.
+
+    Note that `custom_gradient` only supports the TensorFlow and JAX backends.
+
+    Args:
+        f: Function `f(*x)` that returns a tuple `(y, grad_fn)` where:
+            - `x` is a sequence of (nested structures of) tensor inputs to the
+                function.
+            - `y` is a (nested structure of) tensor outputs of applying
+                operations in `f` to `x`.
+            - `grad_fn` is a function with the signature `g(*grad_ys)` which
+                returns a list of tensors the same size as (flattened) `x`: the
+                derivatives of tensors in `y` with respect to the tensors in
+                `x`. `grad_ys` is a sequence of tensors the same size as
+                (flattened) `y` holding the initial value gradients for each
+                tensor in `y`.
+
+    Returns:
+        A function `h(x)` which returns the same value as `f(x)[0]` and whose
+        gradient is determined by `f(x)[1]`.
+
+    Example:
+
+    ```python
+    @ops.custom_gradient
+    def log1pexp(x):
+        e = ops.exp(x)
+
+        def grad(upstream):
+            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
+
+        return ops.log(1 + e), grad
+    ```
+    """
+    return backend.core.custom_gradient(f)
diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py
index c7d90f7af90..a4f03bfa1be 100644
--- a/keras/ops/core_test.py
+++ b/keras/ops/core_test.py
@@ -500,6 +500,44 @@ def test_is_tensor(self):
         self.assertTrue(ops.is_tensor(x))
         self.assertFalse(ops.is_tensor([1, 2, 3]))
 
+    @pytest.mark.skipif(
+        backend.backend() not in ("tensorflow", "jax"),
+        reason=f"{backend.backend()} doesn't support `custom_gradient`.",
+    )
+    def test_custom_gradient(self):
+        @ops.custom_gradient
+        def log1pexp(x):
+            e = ops.exp(x)
+
+            def grad(upstream):
+                return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
+
+            return ops.log(1 + e), grad
+
+        def log1pexp_nan(x):
+            return ops.log(1 + ops.exp(x))
+
+        x = ops.convert_to_tensor(100.0)
+        if backend.backend() == "tensorflow":
+            import tensorflow as tf
+
+            with tf.GradientTape() as tape1:
+                tape1.watch(x)
+                y = log1pexp(x)
+            with tf.GradientTape() as tape2:
+                tape2.watch(x)
+                z = log1pexp_nan(x)
+            dy_dx = tape1.gradient(y, x)
+            dz_dx = tape2.gradient(z, x)
+        elif backend.backend() == "jax":
+            import jax
+
+            dy_dx = jax.grad(log1pexp)(x)
+            dz_dx = jax.grad(log1pexp_nan)(x)
+
+        self.assertEqual(ops.convert_to_numpy(dy_dx), 1.0)
+        self.assertTrue(ops.isnan(dz_dx))
+
 
 class CoreOpsDtypeTest(testing.TestCase, parameterized.TestCase):
     import jax  # enable bfloat16 for numpy
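
Not part of the diff above: the torch backend only raises `NotImplementedError` with a `TODO`. As a purely hypothetical sketch, one way that TODO could later be addressed is to wrap `fun` in a `torch.autograd.Function` so the user-supplied `grad_fn` drives the backward pass. All names below (`wrapper`, `_CustomGradient`, `user_grad_fn`) are illustrative and nothing here is implied to be part of this PR.

```python
import torch


def custom_gradient(fun):
    """Hypothetical torch sketch of `custom_gradient` (not in this PR)."""

    def wrapper(*args):
        class _CustomGradient(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *tensors):
                # `fun` returns `(y, grad_fn)`; stash `grad_fn` for backward.
                y, grad_fn = fun(*tensors)
                ctx.user_grad_fn = grad_fn
                return y

            @staticmethod
            def backward(ctx, *grad_outputs):
                # `backward` must return one gradient per input to `forward`.
                grads = ctx.user_grad_fn(*grad_outputs)
                if isinstance(grads, (list, tuple)):
                    return tuple(grads)
                return grads

        return _CustomGradient.apply(*args)

    return wrapper
```

A real implementation would also have to settle how non-tensor arguments and nested output structures are handled, which the TensorFlow and JAX backends get for free from `tf.custom_gradient` and `jax.custom_gradient`; this sketch glosses over that.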