From dce2fa1b7dcfd25d9573ce3186c5f6e8f79392bb Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:55:00 -0800 Subject: [PATCH] Fixed wrong PRNGKey annotation. --- equinox/nn/_batch_norm.py | 4 ++-- equinox/nn/_spectral_norm.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/equinox/nn/_batch_norm.py b/equinox/nn/_batch_norm.py index dca0353c..6363148d 100644 --- a/equinox/nn/_batch_norm.py +++ b/equinox/nn/_batch_norm.py @@ -4,7 +4,7 @@ import jax import jax.lax as lax import jax.numpy as jnp -from jaxtyping import Array, Bool, Float +from jaxtyping import Array, Bool, Float, PRNGKeyArray from .._module import field from ._sequential import StatefulLayer @@ -111,7 +111,7 @@ def __call__( x: Array, state: State, *, - key: Optional["jax.random.PRNGKey"] = None, # pyright: ignore + key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None, ) -> tuple[Array, State]: """**Arguments:** diff --git a/equinox/nn/_spectral_norm.py b/equinox/nn/_spectral_norm.py index 618df99f..d9f2cc48 100644 --- a/equinox/nn/_spectral_norm.py +++ b/equinox/nn/_spectral_norm.py @@ -111,7 +111,7 @@ def __call__( x: Array, state: State, *, - key: Optional["jax.random.PRNGKey"] = None, # pyright: ignore + key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None ) -> tuple[Array, State]: """**Arguments:**