Skip to content

Commit

Permalink
fix: pyright :)
Browse files Browse the repository at this point in the history
  • Loading branch information
knyazer committed Sep 7, 2024
1 parent ec57ad5 commit ec3bb80
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
4 changes: 2 additions & 2 deletions equinox/_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def __len__(cls) -> int: ...
class Enumeration( # pyright: ignore
enum.Enum, EnumerationItem, metaclass=_Sequence
):
_name_to_item: ClassVar[dict[str, EnumerationItem]]
_index_to_message: ClassVar[list[str]]
_name_to_item: dict[str, EnumerationItem]
_index_to_message: list[str]
_base_offsets: ClassVar[dict["Enumeration", int]]

@classmethod
Expand Down
27 changes: 13 additions & 14 deletions equinox/nn/_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,23 @@ def _padding_init(
num_spatial_dims: int,
) -> Union[str, tuple[tuple[int, int], ...]]:
if isinstance(padding, str):
padding = padding.upper()
if padding not in ("SAME", "SAME_LOWER", "VALID"):
raise ValueError(
"`padding` string must be `'SAME'`, `'SAME_LOWER'`, or `'VALID'`."
)
res = padding.upper()
if res in ("SAME", "SAME_LOWER", "VALID"):
return res
raise ValueError(
"`padding` string must be `'SAME'`, `'SAME_LOWER'`, or `'VALID'`."
)
elif isinstance(padding, int):
padding = tuple((padding, padding) for _ in range(num_spatial_dims))
return tuple((padding, padding) for _ in range(num_spatial_dims))
elif isinstance(padding, Sequence) and len(padding) == num_spatial_dims:
if all_sequences(padding):
padding = tuple(padding)
return tuple(padding) # type: ignore - bug in pyright, infers as tuple[int, ...]
else:
padding = tuple((p, p) for p in padding)
else:
raise ValueError(
"`padding` must either be a string, an int, or tuple of length "
f"{num_spatial_dims} containing ints or tuples of length 2."
)
return padding
return tuple((p, p) for p in padding)
raise ValueError(
"`padding` must either be a string, an int, or tuple of length "
f"{num_spatial_dims} containing ints or tuples of length 2."
)


def _padding_mode_init(padding_mode: str) -> str:
Expand Down

0 comments on commit ec3bb80

Please sign in to comment.