diff --git a/docs/changelog.rst b/docs/changelog.rst index dd71db47..af039789 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -24,6 +24,13 @@ Master - Fix websocket reconnection event. - Fix another websocket reconnect issue where it tried to decode nonexistent headers. +- ext.commands + - Additions + - Added support for the following typing constructs in command signatures: + - ``Union[A, B]`` / ``A | B`` + - ``Optional[T]`` / ``T | None`` + - ``Annotated[T, converter]`` (accessible through the ``typing_extensions`` module on older python versions) + 2.7.0 ====== diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index a596ba39..b0602d93 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -27,6 +27,7 @@ import itertools import copy +import types from typing import Any, Union, Optional, Callable, Awaitable, Tuple, TYPE_CHECKING, List, Type, Set, TypeVar from typing_extensions import Literal @@ -36,12 +37,32 @@ from . import builtin_converter if TYPE_CHECKING: + import sys + from twitchio import Message, Chatter, PartialChatter, Channel, User, PartialUser from . import Cog, Bot from .stringparser import StringParser + + if sys.version_info >= (3, 10): + UnionT = Union[types.UnionType, Union] + else: + UnionT = Union + + __all__ = ("Command", "command", "Group", "Context", "cooldown") +class EmptyArgumentSentinel: + def __repr__(self) -> str: + return "" + + def __eq__(self, __value: object) -> bool: + return False + + +EMPTY = EmptyArgumentSentinel() + + def _boolconverter(param: str): param = param.lower() if param in {"yes", "y", "1", "true", "on"}: @@ -114,40 +135,127 @@ def full_name(self) -> str: return self._name return f"{self.parent.full_name} {self._name}" - def _resolve_converter(self, converter: Union[Callable, Awaitable, type]) -> Union[Callable[..., Any]]: + def _is_optional_argument(self, converter: Any): + return (getattr(converter, "__origin__", None) is Union or isinstance(converter, types.UnionType)) and type( + None + ) in converter.__args__ + + def resolve_union_callback(self, name: str, converter: UnionT) -> Callable[[Context, str], Any]: + # print(type(converter), converter.__args__) + + args = converter.__args__ # type: ignore # pyright doesnt like this + + async def _resolve(context: Context, arg: str) -> Any: + t = EMPTY + last = None + + for original in args: + underlying = self._resolve_converter(name, original) + + try: + t: Any = underlying(context, arg) + if inspect.iscoroutine(t): + t = await t + + break + except Exception as l: + last = l + t = EMPTY # thisll get changed when t is a coroutine, but is still invalid, so roll it back + continue + + if t is EMPTY: + raise UnionArgumentParsingFailed(name, args) + + return t + + return _resolve + + def resolve_optional_callback(self, name: str, converter: Any) -> Callable[[Context, str], Any]: + underlying = self._resolve_converter(name, converter.__args__[0]) + + async def _resolve(context: Context, arg: str) -> Any: + try: + t: Any = underlying(context, arg) + if inspect.iscoroutine(t): + t = await t + + except Exception: + return EMPTY # instruct the parser to roll back and ignore this argument + + return t + + return _resolve + + def _resolve_converter(self, name: str, converter: Union[Callable, Awaitable, type]) -> Callable[..., Any]: if ( isinstance(converter, type) and converter.__module__.startswith("twitchio") and converter in builtin_converter._mapping ): - return builtin_converter._mapping[converter] - return converter + return self._convert_builtin_type(name, converter, builtin_converter._mapping[converter]) + + elif converter is bool: + converter = self._convert_builtin_type(name, bool, _boolconverter) + + elif converter in (str, int): + converter = self._convert_builtin_type(name, converter, converter) # type: ignore + + elif self._is_optional_argument(converter): + return self.resolve_optional_callback(name, converter) + + elif isinstance(converter, types.UnionType) or getattr(converter, "__origin__", None) is Union: + return self.resolve_union_callback(name, converter) # type: ignore + + elif hasattr(converter, "__metadata__"): # Annotated + annotated = converter.__metadata__ # type: ignore + return self._resolve_converter(name, annotated[0]) + + return converter # type: ignore + + def _convert_builtin_type( + self, arg_name: str, original: type, converter: Union[Callable[[str], Any], Callable[[str], Awaitable[Any]]] + ) -> Callable[[Context, str], Awaitable[Any]]: + async def resolve(_, arg: str) -> Any: + try: + t = converter(arg) + + if inspect.iscoroutine(t): + t = await t + + return t + except Exception as e: + raise ArgumentParsingFailed( + f"Failed to convert `{arg}` to expected type {original.__name__} for argument `{arg_name}`", + original=e, + argname=arg_name, + expected=original, + ) from e + + return resolve async def _convert_types(self, context: Context, param: inspect.Parameter, parsed: str) -> Any: converter = param.annotation + if converter is param.empty: if param.default in (param.empty, None): converter = str else: converter = type(param.default) - true_converter = self._resolve_converter(converter) + + true_converter = self._resolve_converter(param.name, converter) try: - if true_converter in (int, str): - argument = true_converter(parsed) - elif true_converter is bool: - argument = _boolconverter(parsed) - else: - argument = true_converter(context, parsed) + argument = true_converter(context, parsed) if inspect.iscoroutine(argument): argument = await argument - except BadArgument: + except BadArgument as e: + if e.name is None: + e.name = param.name + raise except Exception as e: raise ArgumentParsingFailed( - f"Invalid argument parsed at `{param.name}` in command `{self.name}`." - f" Expected type {converter} got {type(parsed)}.", - e, + f"Failed to parse `{parsed}` for argument {param.name}", original=e, argname=param.name, expected=None ) from e return argument @@ -170,12 +278,26 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di try: argument = parsed.pop(index) except (KeyError, IndexError): + if self._is_optional_argument(param.annotation): # parameter is optional and at the end. + args.append(param.default if param.default is not param.empty else None) + continue + if param.default is param.empty: - raise MissingRequiredArgument(param) + raise MissingRequiredArgument(argname=param.name) + args.append(param.default) else: - argument = await self._convert_types(context, param, argument) - args.append(argument) + _parsed_arg = await self._convert_types(context, param, argument) + + if _parsed_arg is EMPTY: + parsed[index] = argument + index -= 1 + args.append(param.default if param.default is not param.empty else None) + + continue + else: + args.append(_parsed_arg) + elif param.kind == param.KEYWORD_ONLY: rest = " ".join(parsed.values()) if rest.startswith(" "): @@ -183,13 +305,13 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di if rest: rest = await self._convert_types(context, param, rest) elif param.default is param.empty: - raise MissingRequiredArgument(param) + raise MissingRequiredArgument(argname=param.name) else: rest = param.default kwargs[param.name] = rest parsed.clear() break - elif param.VAR_POSITIONAL: + elif param.kind == param.VAR_POSITIONAL: args.extend([await self._convert_types(context, param, argument) for argument in parsed.values()]) parsed.clear() break diff --git a/twitchio/ext/commands/errors.py b/twitchio/ext/commands/errors.py index 04ba0ab9..eeaa2229 100644 --- a/twitchio/ext/commands/errors.py +++ b/twitchio/ext/commands/errors.py @@ -21,6 +21,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from .core import Command class TwitchCommandError(Exception): @@ -38,29 +44,52 @@ class InvalidCog(TwitchCommandError): class MissingRequiredArgument(TwitchCommandError): - pass + def __init__(self, *args, argname: Optional[str] = None) -> None: + self.name: str = argname or "unknown" + + if args: + super().__init__(*args) + else: + super().__init__(f"Missing required argument `{self.name}`") class BadArgument(TwitchCommandError): - def __init__(self, message: str): + def __init__(self, message: str, argname: Optional[str] = None): + self.name: str = argname # type: ignore # this'll get fixed in the parser handler self.message = message super().__init__(message) class ArgumentParsingFailed(BadArgument): - def __init__(self, message: str, original: Exception): - self.original = original - super().__init__(message) + def __init__( + self, message: str, original: Exception, argname: Optional[str] = None, expected: Optional[type] = None + ): + self.original: Exception = original + self.name: str = argname # type: ignore # in theory this'll never be None but if someone is creating this themselves itll be none. + self.expected_type: Optional[type] = expected + + Exception.__init__(self, message) # bypass badArgument + + +class UnionArgumentParsingFailed(ArgumentParsingFailed): + def __init__(self, argname: str, expected: tuple[type, ...]): + self.name: str = argname + self.expected_type: tuple[type, ...] = expected + + self.message = f"Failed to convert argument `{self.name}` to any of the valid options" + Exception.__init__(self, self.message) class CommandNotFound(TwitchCommandError): - pass + def __init__(self, message: str, name: str) -> None: + self.name: str = name + super().__init__(message) class CommandOnCooldown(TwitchCommandError): - def __init__(self, command, retry_after): - self.command = command - self.retry_after = retry_after + def __init__(self, command: Command, retry_after: float): + self.command: Command = command + self.retry_after: float = retry_after super().__init__(f"Command <{command.name}> is on cooldown. Try again in ({retry_after:.2f})s")