Skip to content

Commit

Permalink
[commands] command typehint updates (#425)
Browse files Browse the repository at this point in the history
* Add ability to parse Unions into ext.commands

* pipe was 3.10, not 3.8

* Fix optional parsing when optional is the last argument

* Add Annotated support

* Run black

* revamp errors with proper useful messages and details on objects

* update changelog with changes

* run black
  • Loading branch information
IAmTomahawkx authored Sep 19, 2023
1 parent 9628522 commit 4ae7f4a
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 28 deletions.
7 changes: 7 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ Master
- Fix websocket reconnection event.
- Fix another websocket reconnect issue where it tried to decode nonexistent headers.

- ext.commands
- Additions
- Added support for the following typing constructs in command signatures:
- ``Union[A, B]`` / ``A | B``
- ``Optional[T]`` / ``T | None``
- ``Annotated[T, converter]`` (accessible through the ``typing_extensions`` module on older python versions)


2.7.0
======
Expand Down
160 changes: 141 additions & 19 deletions twitchio/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import itertools
import copy
import types
from typing import Any, Union, Optional, Callable, Awaitable, Tuple, TYPE_CHECKING, List, Type, Set, TypeVar
from typing_extensions import Literal

Expand All @@ -36,12 +37,32 @@
from . import builtin_converter

if TYPE_CHECKING:
import sys

from twitchio import Message, Chatter, PartialChatter, Channel, User, PartialUser
from . import Cog, Bot
from .stringparser import StringParser

if sys.version_info >= (3, 10):
UnionT = Union[types.UnionType, Union]
else:
UnionT = Union


__all__ = ("Command", "command", "Group", "Context", "cooldown")


class EmptyArgumentSentinel:
def __repr__(self) -> str:
return "<EMPTY>"

def __eq__(self, __value: object) -> bool:
return False


EMPTY = EmptyArgumentSentinel()


def _boolconverter(param: str):
param = param.lower()
if param in {"yes", "y", "1", "true", "on"}:
Expand Down Expand Up @@ -114,40 +135,127 @@ def full_name(self) -> str:
return self._name
return f"{self.parent.full_name} {self._name}"

def _resolve_converter(self, converter: Union[Callable, Awaitable, type]) -> Union[Callable[..., Any]]:
def _is_optional_argument(self, converter: Any):
return (getattr(converter, "__origin__", None) is Union or isinstance(converter, types.UnionType)) and type(
None
) in converter.__args__

def resolve_union_callback(self, name: str, converter: UnionT) -> Callable[[Context, str], Any]:
# print(type(converter), converter.__args__)

args = converter.__args__ # type: ignore # pyright doesnt like this

async def _resolve(context: Context, arg: str) -> Any:
t = EMPTY
last = None

for original in args:
underlying = self._resolve_converter(name, original)

try:
t: Any = underlying(context, arg)
if inspect.iscoroutine(t):
t = await t

break
except Exception as l:
last = l
t = EMPTY # thisll get changed when t is a coroutine, but is still invalid, so roll it back
continue

if t is EMPTY:
raise UnionArgumentParsingFailed(name, args)

return t

return _resolve

def resolve_optional_callback(self, name: str, converter: Any) -> Callable[[Context, str], Any]:
underlying = self._resolve_converter(name, converter.__args__[0])

async def _resolve(context: Context, arg: str) -> Any:
try:
t: Any = underlying(context, arg)
if inspect.iscoroutine(t):
t = await t

except Exception:
return EMPTY # instruct the parser to roll back and ignore this argument

return t

return _resolve

def _resolve_converter(self, name: str, converter: Union[Callable, Awaitable, type]) -> Callable[..., Any]:
if (
isinstance(converter, type)
and converter.__module__.startswith("twitchio")
and converter in builtin_converter._mapping
):
return builtin_converter._mapping[converter]
return converter
return self._convert_builtin_type(name, converter, builtin_converter._mapping[converter])

elif converter is bool:
converter = self._convert_builtin_type(name, bool, _boolconverter)

elif converter in (str, int):
converter = self._convert_builtin_type(name, converter, converter) # type: ignore

elif self._is_optional_argument(converter):
return self.resolve_optional_callback(name, converter)

elif isinstance(converter, types.UnionType) or getattr(converter, "__origin__", None) is Union:
return self.resolve_union_callback(name, converter) # type: ignore

elif hasattr(converter, "__metadata__"): # Annotated
annotated = converter.__metadata__ # type: ignore
return self._resolve_converter(name, annotated[0])

return converter # type: ignore

def _convert_builtin_type(
self, arg_name: str, original: type, converter: Union[Callable[[str], Any], Callable[[str], Awaitable[Any]]]
) -> Callable[[Context, str], Awaitable[Any]]:
async def resolve(_, arg: str) -> Any:
try:
t = converter(arg)

if inspect.iscoroutine(t):
t = await t

return t
except Exception as e:
raise ArgumentParsingFailed(
f"Failed to convert `{arg}` to expected type {original.__name__} for argument `{arg_name}`",
original=e,
argname=arg_name,
expected=original,
) from e

return resolve

async def _convert_types(self, context: Context, param: inspect.Parameter, parsed: str) -> Any:
converter = param.annotation

if converter is param.empty:
if param.default in (param.empty, None):
converter = str
else:
converter = type(param.default)
true_converter = self._resolve_converter(converter)

true_converter = self._resolve_converter(param.name, converter)

try:
if true_converter in (int, str):
argument = true_converter(parsed)
elif true_converter is bool:
argument = _boolconverter(parsed)
else:
argument = true_converter(context, parsed)
argument = true_converter(context, parsed)
if inspect.iscoroutine(argument):
argument = await argument
except BadArgument:
except BadArgument as e:
if e.name is None:
e.name = param.name

raise
except Exception as e:
raise ArgumentParsingFailed(
f"Invalid argument parsed at `{param.name}` in command `{self.name}`."
f" Expected type {converter} got {type(parsed)}.",
e,
f"Failed to parse `{parsed}` for argument {param.name}", original=e, argname=param.name, expected=None
) from e
return argument

Expand All @@ -170,26 +278,40 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di
try:
argument = parsed.pop(index)
except (KeyError, IndexError):
if self._is_optional_argument(param.annotation): # parameter is optional and at the end.
args.append(param.default if param.default is not param.empty else None)
continue

if param.default is param.empty:
raise MissingRequiredArgument(param)
raise MissingRequiredArgument(argname=param.name)

args.append(param.default)
else:
argument = await self._convert_types(context, param, argument)
args.append(argument)
_parsed_arg = await self._convert_types(context, param, argument)

if _parsed_arg is EMPTY:
parsed[index] = argument
index -= 1
args.append(param.default if param.default is not param.empty else None)

continue
else:
args.append(_parsed_arg)

elif param.kind == param.KEYWORD_ONLY:
rest = " ".join(parsed.values())
if rest.startswith(" "):
rest = rest.lstrip(" ")
if rest:
rest = await self._convert_types(context, param, rest)
elif param.default is param.empty:
raise MissingRequiredArgument(param)
raise MissingRequiredArgument(argname=param.name)
else:
rest = param.default
kwargs[param.name] = rest
parsed.clear()
break
elif param.VAR_POSITIONAL:
elif param.kind == param.VAR_POSITIONAL:
args.extend([await self._convert_types(context, param, argument) for argument in parsed.values()])
parsed.clear()
break
Expand Down
47 changes: 38 additions & 9 deletions twitchio/ext/commands/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
from .core import Command


class TwitchCommandError(Exception):
Expand All @@ -38,29 +44,52 @@ class InvalidCog(TwitchCommandError):


class MissingRequiredArgument(TwitchCommandError):
pass
def __init__(self, *args, argname: Optional[str] = None) -> None:
self.name: str = argname or "unknown"

if args:
super().__init__(*args)
else:
super().__init__(f"Missing required argument `{self.name}`")


class BadArgument(TwitchCommandError):
def __init__(self, message: str):
def __init__(self, message: str, argname: Optional[str] = None):
self.name: str = argname # type: ignore # this'll get fixed in the parser handler
self.message = message
super().__init__(message)


class ArgumentParsingFailed(BadArgument):
def __init__(self, message: str, original: Exception):
self.original = original
super().__init__(message)
def __init__(
self, message: str, original: Exception, argname: Optional[str] = None, expected: Optional[type] = None
):
self.original: Exception = original
self.name: str = argname # type: ignore # in theory this'll never be None but if someone is creating this themselves itll be none.
self.expected_type: Optional[type] = expected

Exception.__init__(self, message) # bypass badArgument


class UnionArgumentParsingFailed(ArgumentParsingFailed):
def __init__(self, argname: str, expected: tuple[type, ...]):
self.name: str = argname
self.expected_type: tuple[type, ...] = expected

self.message = f"Failed to convert argument `{self.name}` to any of the valid options"
Exception.__init__(self, self.message)


class CommandNotFound(TwitchCommandError):
pass
def __init__(self, message: str, name: str) -> None:
self.name: str = name
super().__init__(message)


class CommandOnCooldown(TwitchCommandError):
def __init__(self, command, retry_after):
self.command = command
self.retry_after = retry_after
def __init__(self, command: Command, retry_after: float):
self.command: Command = command
self.retry_after: float = retry_after
super().__init__(f"Command <{command.name}> is on cooldown. Try again in ({retry_after:.2f})s")


Expand Down

0 comments on commit 4ae7f4a

Please sign in to comment.