Add exact spectral norm feature #831

Open · wants to merge 6 commits into main

Changes from 2 commits
121 changes: 97 additions & 24 deletions equinox/nn/_spectral_norm.py
@@ -1,18 +1,19 @@
from typing import Generic, Optional, TypeVar
from typing import Callable, Generic, Optional, TypeVar

import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray

from .._eval_shape import filter_eval_shape
from .._module import field
from .._tree import tree_at
from ._sequential import StatefulLayer
from ._stateful import State, StateIndex


def _power_iteration(weight, u, v, eps):
def _power_iteration_old(weight, u, v, eps):
u = weight @ v
u_norm = jnp.sqrt(jnp.sum(u**2))
u = u / jnp.maximum(eps, u_norm)
@@ -24,6 +25,18 @@ def _power_iteration(weight, u, v, eps):
return u, v


def _power_iteration(forward, transpose, v_prev, eps):
u = forward(v_prev)
u_norm = jnp.sqrt(jnp.sum(u**2))
u = u / jnp.maximum(eps, u_norm)

v = transpose(u)[0]
v_norm = jnp.sqrt(jnp.sum(v**2))
v = v / jnp.maximum(eps, v_norm)

return u, v


_Layer = TypeVar("_Layer")


@@ -42,6 +55,12 @@ class SpectralNorm(StatefulLayer, Generic[_Layer], strict=True):
[Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957)
for more details and motivation.

Default approaches to spectral normalization rely on an inexact approximation of
the spectral norm, although this approximation often performs better in practice; see
[Why Spectral Normalization Stabilizes GANs: Analysis and Improvements](https://arxiv.org/abs/2009.02773)
and [Generalizable Adversarial Training via Spectral Normalization](https://arxiv.org/abs/1811.07457).
Equinox offers functionality for both the exact and the approximate spectral norm.

!!! example

See [this example](../../examples/stateful.ipynb) for example usage.
@@ -53,6 +72,8 @@ class SpectralNorm(StatefulLayer, Generic[_Layer], strict=True):
""" # noqa: E501

layer: _Layer
reverse: Optional[Callable]
exact: bool
weight_name: str = field(static=True)
uv_index: StateIndex[tuple[Float[Array, " u_size"], Float[Array, " v_size"]]]
num_power_iterations: int = field(static=True)
@@ -66,6 +87,8 @@ def __init__(
num_power_iterations: int = 1,
eps: float = 1e-12,
inference: bool = False,
exact: bool = False,
input_shape: Optional[jax.ShapeDtypeStruct] = None,
*,
key: PRNGKeyArray,
):
@@ -81,6 +104,11 @@ def __init__(
- `inference`: Whether this is in inference mode, at which time no power
iterations are performed. This may be toggled with
[`equinox.nn.inference_mode`][].
- `exact`: Whether to compute the exact linear transpose for the power
iteration. Traditional approaches reshape >2D linear operators into a
matrix, rather than taking the linear transpose in >2D.
- `input_shape`: If `exact` is true, the input structure to the layer must be
specified as a `jax.ShapeDtypeStruct`.
- `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation.
(Keyword only argument.)

Expand All @@ -90,6 +118,13 @@ def __init__(
The `dtype` of the weight array of the `layer` input is applied to all
parameters in this layer.


!!! Caution

If `exact` is true, the transpose is computed via `jax.linear_transpose` of
the layer. This traces every operation in the layer's call, so for layers
with a bias it can produce an incorrect spectral norm.

Comment on lines +108 to +112

Owner:

Maybe we should follow JAX's lead here and transpose the tangent pass of `jax.jvp`?

Contributor Author:

Hmmm, I tried to implement what I thought you meant. This also means we could remove the "weight" flag for the exact case (maybe?), since we basically "determine" the weight through the jvp?

Let me know if this is what you had in mind, or if I was totally off. It does seem like a lot of jvps, though.
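
For reference, a minimal sketch of what transposing the tangent pass could look like, assuming a bias-free layer and an illustrative `dummy_input`; this is only an assumption about the suggested approach, not the code in this PR:

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

# Illustrative only: linearise the layer around a dummy input, then transpose
# the resulting tangent map. For a purely linear (bias-free) layer this agrees
# with jax.linear_transpose(layer, dummy_input); with a bias it transposes the
# Jacobian rather than the affine map itself.
layer = eqx.nn.Linear(5, 6, use_bias=False, key=jr.PRNGKey(0))
dummy_input = jnp.zeros((5,))

# jax.linearize returns the primal output and a linear map t -> J @ t.
_, jvp_fn = jax.linearize(layer, dummy_input)

# Transposing that linear map gives u -> J.T @ u, usable as the reverse pass
# in the power iteration.
transpose_fn = jax.linear_transpose(jvp_fn, dummy_input)

u = jnp.ones((6,))
(v,) = transpose_fn(u)  # linear_transpose returns a tuple of cotangents
```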

"""
self.layer = layer
self.weight_name = weight_name
@@ -98,17 +133,32 @@ def __init__(
self.inference = inference

weight = getattr(layer, weight_name)
if weight.ndim < 2:
raise ValueError("`weight` must be at least two-dimensional")
weight = jnp.reshape(weight, (weight.shape[0], -1))
dtype = weight.dtype
u_len, v_len = weight.shape
ukey, vkey = jr.split(key)
u0 = jr.normal(ukey, (u_len,), dtype=dtype)
v0 = jr.normal(vkey, (v_len,), dtype=dtype)
for _ in range(15):
u0, v0 = _power_iteration(weight, u0, v0, eps)

if not exact:
if weight.ndim < 2:
raise ValueError("`weight` must be at least two-dimensional")
weight = jnp.reshape(weight, (weight.shape[0], -1))
dtype = weight.dtype
u_len, v_len = weight.shape
self.reverse = None
u0 = jr.normal(ukey, (u_len,), dtype=dtype)
v0 = jr.normal(vkey, (v_len,), dtype=dtype)
for _ in range(15):
u0, v0 = _power_iteration_old(weight, u0, v0, eps)
else:
assert (
input_shape is not None
), "Must specify `input_shape` to use exact spectral norm!"
assert isinstance(self.layer, Callable)
u_shape = filter_eval_shape(self.layer, input_shape)
u0 = jr.normal(ukey, u_shape.shape, dtype=u_shape.dtype)
v0 = jr.normal(vkey, input_shape.shape, dtype=input_shape.dtype)
self.reverse = jax.linear_transpose(self.layer, input_shape)
for _ in range(15):
u0, v0 = _power_iteration(self.layer, self.reverse, v0, self.eps)
self.uv_index = StateIndex((u0, v0))
self.exact = exact

@jax.named_scope("eqx.nn.SpectralNorm")
def __call__(
@@ -141,17 +191,40 @@ def __call__(

u, v = state.get(self.uv_index)
weight = getattr(self.layer, self.weight_name)
weight_shape = weight.shape
weight = jnp.reshape(weight, (weight.shape[0], -1))
if inference is None:
inference = self.inference
if not inference:
stop_weight = lax.stop_gradient(weight)
for _ in range(self.num_power_iterations):
u, v = _power_iteration(stop_weight, u, v, self.eps)
state = state.set(self.uv_index, (u, v))
σ = jnp.einsum("i,ij,j->", u, weight, v)
σ_weight = jnp.reshape(weight / σ, weight_shape)
layer = tree_at(lambda l: getattr(l, self.weight_name), self.layer, σ_weight)
out = layer(x)
if not self.exact:
weight_shape = weight.shape
weight = jnp.reshape(weight, (weight.shape[0], -1))
if inference is None:
inference = self.inference
if not inference:
stop_weight = lax.stop_gradient(weight)
for _ in range(self.num_power_iterations):
u, v = _power_iteration_old(stop_weight, u, v, self.eps)
state = state.set(self.uv_index, (u, v))
σ = jnp.einsum("i,ij,j->", u, weight, v)
σ_weight = jnp.reshape(weight / σ, weight_shape)
layer = tree_at(
lambda l: getattr(l, self.weight_name), self.layer, σ_weight
)
out = layer(x)
else:
if inference is None:
inference = self.inference
if not inference:
stop_weight = lax.stop_gradient(weight)
layer = tree_at(
lambda l: getattr(l, self.weight_name), self.layer, stop_weight
)
for _ in range(self.num_power_iterations):
u, v = _power_iteration(layer, self.reverse, v, self.eps)
state = state.set(self.uv_index, (u, v))
else:
layer = self.layer
assert callable(layer)
σ = jnp.sum(u * layer(v))
σ_weight = weight / σ
layer = tree_at(
lambda l: getattr(l, self.weight_name), self.layer, σ_weight
)
out = layer(x)
return out, state
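
As a quick usage sketch of the new API: the layer, shapes, and keys below are illustrative, while the `exact` and `input_shape` arguments and the stateful call follow the diff above.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

layer_key, sn_key = jr.split(jr.PRNGKey(0))

# A bias-free conv, per the caution above about transposing layers with a bias.
conv = eqx.nn.Conv2d(3, 8, kernel_size=3, use_bias=False, key=layer_key)

x = jnp.zeros((3, 32, 32))
spectral = eqx.nn.SpectralNorm(
    conv,
    "weight",
    exact=True,
    input_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    key=sn_key,
)
state = eqx.nn.State(spectral)

out, state = spectral(x, state)  # power iteration updates the stored (u, v)
```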
46 changes: 46 additions & 0 deletions tests/test_nn.py
@@ -1069,6 +1069,52 @@ def λ1():
assert out.shape == (4, 6, 6, 6)


def test_spectral_norm_exact(getkey):
def λ1():
u, v = state.get(spectral.uv_index)
σ = jnp.sum(u * spectral.layer(v))
_, s, _ = jnp.linalg.svd(spectral.layer.weight / σ) # pyright: ignore
return s[0]

x = jrandom.normal(getkey(), (5,))
spectral = eqx.nn.SpectralNorm(
eqx.nn.Linear(5, 6, key=getkey(), use_bias=False),
"weight",
exact=True,
input_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
key=getkey(),
)
state = eqx.nn.State(spectral)
for _ in range(200):
_, state = spectral(x, state)
assert jnp.allclose(λ1(), 1)

# Test not updated at inference time
spectral = eqx.tree_at(
lambda s: s.layer.weight, spectral, spectral.layer.weight + 1
)
spectral = eqx.nn.inference_mode(spectral, value=True)
assert not jnp.allclose(λ1(), 1)
for _ in range(100):
_, state = spectral(x, state)
assert not jnp.allclose(λ1(), 1)

# Test >2 dimensional input

x = jrandom.normal(getkey(), (5, 8, 8, 8))
conv = eqx.nn.Conv3d(5, 4, 3, key=getkey(), use_bias=False)
spectral = eqx.nn.SpectralNorm(
conv,
"weight",
exact=True,
input_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
key=getkey(),
)
state = eqx.nn.State(spectral)
out, _ = spectral(x, state)
assert out.shape == (4, 6, 6, 6)


def test_weight_norm(getkey):
# Linear
linear = eqx.nn.Linear(4, 4, key=getkey())