Add custom gradient (#19279)
* Add `custom_gradient` to `ops.core`

* Add comment
james77777778 authored Mar 11, 2024
1 parent 404e8f3 commit b97338e
Showing 6 changed files with 104 additions and 0 deletions.
4 changes: 4 additions & 0 deletions keras/backend/jax/core.py
@@ -347,6 +347,10 @@ def unstack(x, num=None, axis=0):
]


def custom_gradient(fun):
return jax.custom_gradient(fun=fun)


def device_scope(device_name):
if isinstance(device_name, str):
# We support string value like "cpu:0", "gpu:1", etc.
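For context (not part of the diff): the JAX backend defers directly to `jax.custom_gradient`, which expects the same TF-style convention of `fun` returning `(output, grad_fn)`. A minimal standalone sketch in plain JAX, mirroring the `log1pexp` example from the docstring further down:

```python
import jax
import jax.numpy as jnp


@jax.custom_gradient
def log1pexp(x):
    e = jnp.exp(x)

    def grad(upstream):
        # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x)
        return upstream * (1.0 - 1.0 / (1.0 + e))

    return jnp.log(1.0 + e), grad


# Autodiff of the naive formula evaluates to nan at large x (as the test
# below checks); the custom rule stays finite.
print(jax.grad(log1pexp)(100.0))  # 1.0
```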
6 changes: 6 additions & 0 deletions keras/backend/numpy/core.py
@@ -229,3 +229,9 @@ def stop_gradient(x):
def unstack(x, num=None, axis=0):
x = np.moveaxis(x, axis, 0)
return [x[i] for i in range(x.shape[0])]


def custom_gradient(fun):
raise NotImplementedError(
"`custom_gradient` is not supported with numpy backend"
)
4 changes: 4 additions & 0 deletions keras/backend/tensorflow/core.py
@@ -256,6 +256,10 @@ def unstack(x, num=None, axis=0):
return tf.unstack(x, num=num, axis=axis)


def custom_gradient(fun):
return tf.custom_gradient(f=fun)


class name_scope(base_name_scope):
def __init__(self, name, **kwargs):
super().__init__(name, **kwargs)
7 changes: 7 additions & 0 deletions keras/backend/torch/core.py
@@ -435,3 +435,10 @@ def stop_gradient(variable):

def unstack(x, num=None, axis=0):
return x.unbind(axis)


def custom_gradient(fun):
# TODO: Support this function
raise NotImplementedError(
"`custom_gradient` is not supported with torch backend"
)
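Not part of this commit, but for illustration: one way the torch backend could eventually support this is via `torch.autograd.Function`. A rough, hypothetical sketch (the names `custom_gradient_sketch` and `_Wrapped` are made up; single output, positional args only):

```python
import torch


def custom_gradient_sketch(fun):
    # Hypothetical wrapper, not the shipped Keras implementation: `fun`
    # is expected to return `(output, grad_fn)`, and `grad_fn` is used
    # as the backward rule.
    class _Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            output, grad_fn = fun(*args)
            ctx.user_grad_fn = grad_fn  # stash the user rule for backward
            return output

        @staticmethod
        def backward(ctx, *upstream):
            grads = ctx.user_grad_fn(*upstream)
            # `backward` must return one gradient per input to `forward`.
            return tuple(grads) if isinstance(grads, (list, tuple)) else (grads,)

    def wrapper(*args):
        return _Wrapped.apply(*args)

    return wrapper
```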
45 changes: 45 additions & 0 deletions keras/ops/core.py
@@ -11,6 +11,7 @@
convert_to_numpy
cond
is_tensor
custom_gradient
"""

import numpy as np
@@ -623,3 +624,47 @@ def is_tensor(x):
`True` if `x` is a tensor, otherwise `False`.
"""
return backend.core.is_tensor(x)


@keras_export("keras.ops.custom_gradient")
def custom_gradient(f):
"""Decorator to define a function with a custom gradient.
This decorator allows fine-grained control over the gradients of a sequence
of operations. This may be useful for multiple reasons, including providing
a more efficient or numerically stable gradient for a sequence of
operations.
Note that `custom_gradient` only supports TensorFlow and JAX backends.
Args:
f: Function `f(*x)` that returns a tuple `(y, grad_fn)` where:
- `x` is a sequence of (nested structures of) tensor inputs to the
function.
- `y` is a (nested structure of) tensor outputs of applying
operations in `f` to `x`.
- `grad_fn` is a function with the signature `g(*grad_ys)` which
returns a list of tensors the same size as (flattened) `x`: the
derivatives of tensors in `y` with respect to the tensors in
`x`. `grad_ys` is a sequence of tensors the same size as
(flattened) `y` holding the initial value gradients for each
tensor in `y`.
Returns:
A function `h(x)` which returns the same value as `f(x)[0]` and whose
gradient is determined by `f(x)[1]`.
Example:
```python
@ops.custom_gradient
def log1pexp(x):
e = ops.exp(x)
def grad(upstream):
return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
return ops.log(1 + e), grad
```
"""
return backend.core.custom_gradient(f)
38 changes: 38 additions & 0 deletions keras/ops/core_test.py
@@ -500,6 +500,44 @@ def test_is_tensor(self):
self.assertTrue(ops.is_tensor(x))
self.assertFalse(ops.is_tensor([1, 2, 3]))

@pytest.mark.skipif(
backend.backend() not in ("tensorflow", "jax"),
reason=f"{backend.backend()} doesn't support `custom_gradient`.",
)
def test_custom_gradient(self):
@ops.custom_gradient
def log1pexp(x):
e = ops.exp(x)

def grad(upstream):
return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

return ops.log(1 + e), grad

def log1pexp_nan(x):
return ops.log(1 + ops.exp(x))

x = ops.convert_to_tensor(100.0)
if backend.backend() == "tensorflow":
import tensorflow as tf

with tf.GradientTape() as tape1:
tape1.watch(x)
y = log1pexp(x)
with tf.GradientTape() as tape2:
tape2.watch(x)
z = log1pexp_nan(x)
dy_dx = tape1.gradient(y, x)
dz_dx = tape2.gradient(z, x)
elif backend.backend() == "jax":
import jax

dy_dx = jax.grad(log1pexp)(x)
dz_dx = jax.grad(log1pexp_nan)(x)

self.assertEqual(ops.convert_to_numpy(dy_dx), 1.0)
self.assertTrue(ops.isnan(dz_dx))


class CoreOpsDtypeTest(testing.TestCase, parameterized.TestCase):
import jax # enable bfloat16 for numpy