Skip to content

Commit

Permalink
Fixed wrong PRNGKey annotation.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Nov 27, 2023
1 parent 588a8c3 commit dce2fa1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions equinox/nn/_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float
from jaxtyping import Array, Bool, Float, PRNGKeyArray

from .._module import field
from ._sequential import StatefulLayer
Expand Down Expand Up @@ -111,7 +111,7 @@ def __call__(
x: Array,
state: State,
*,
key: Optional["jax.random.PRNGKey"] = None, # pyright: ignore
key: Optional[PRNGKeyArray] = None,
inference: Optional[bool] = None,
) -> tuple[Array, State]:
"""**Arguments:**
Expand Down
2 changes: 1 addition & 1 deletion equinox/nn/_spectral_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __call__(
x: Array,
state: State,
*,
key: Optional["jax.random.PRNGKey"] = None, # pyright: ignore
key: Optional[PRNGKeyArray] = None,
inference: Optional[bool] = None
) -> tuple[Array, State]:
"""**Arguments:**
Expand Down

0 comments on commit dce2fa1

Please sign in to comment.