Skip to content

Commit

Permalink
Use ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
inyutin committed Oct 26, 2024
1 parent fcf2757 commit c0cb00b
Show file tree
Hide file tree
Showing 10 changed files with 376 additions and 288 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements_ci.txt
- name: Test with ruff
run: |
ruff check .
ruff format --diff .
- name: Test with mypy
run: mypy -m aiohttp_retry

Expand Down
4 changes: 2 additions & 2 deletions aiohttp_retry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .client import * # noqa: F401, F403
from .retry_options import * # noqa: F401, F403
from .client import * # noqa: F403
from .retry_options import * # noqa: F403
165 changes: 85 additions & 80 deletions aiohttp_retry/client.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
from __future__ import annotations

import asyncio
import logging
import sys
from abc import abstractmethod
from collections.abc import Awaitable, Callable, Generator
from dataclasses import dataclass
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
)

Expand All @@ -23,30 +18,36 @@

from .retry_options import ExponentialRetry, RetryOptionsBase

_MIN_SERVER_ERROR_STATUS = 500

if TYPE_CHECKING:
from types import TracebackType

if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
from typing import Protocol


class _Logger(Protocol):
"""
_Logger defines which methods logger object should have
"""
"""_Logger defines which methods logger object should have."""

@abstractmethod
def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: pass
def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
pass

@abstractmethod
def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: pass
def warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
pass

@abstractmethod
def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: pass
def exception(self, msg: str, *args: Any, **kwargs: Any) -> None:
pass


# url itself or list of urls for changing between retries
_RAW_URL_TYPE = Union[StrOrURL, YARL_URL]
_URL_TYPE = Union[_RAW_URL_TYPE, List[_RAW_URL_TYPE], Tuple[_RAW_URL_TYPE, ...]]
_URL_TYPE = Union[_RAW_URL_TYPE, list[_RAW_URL_TYPE], tuple[_RAW_URL_TYPE, ...]]
_LoggerType = Union[_Logger, logging.Logger]

RequestFunc = Callable[..., Awaitable[ClientResponse]]
Expand All @@ -56,35 +57,35 @@ def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: pass
class RequestParams:
method: str
url: _RAW_URL_TYPE
headers: Optional[Dict[str, Any]] = None
trace_request_ctx: Optional[Dict[str, Any]] = None
kwargs: Optional[Dict[str, Any]] = None
headers: dict[str, Any] | None = None
trace_request_ctx: dict[str, Any] | None = None
kwargs: dict[str, Any] | None = None


class _RequestContext:
def __init__(
self,
request_func: RequestFunc,
params_list: List[RequestParams],
params_list: list[RequestParams],
logger: _LoggerType,
retry_options: RetryOptionsBase,
raise_for_status: bool = False,
) -> None:
assert len(params_list) > 0
assert len(params_list) > 0 # noqa: S101

self._request_func = request_func
self._params_list = params_list
self._logger = logger
self._retry_options = retry_options
self._raise_for_status = raise_for_status

self._response: Optional[ClientResponse] = None
self._response: ClientResponse | None = None

async def _is_skip_retry(self, current_attempt: int, response: ClientResponse) -> bool:
if current_attempt == self._retry_options.attempts:
return True

if response.status >= 500 and self._retry_options.retry_all_server_errors:
if response.status >= _MIN_SERVER_ERROR_STATUS and self._retry_options.retry_all_server_errors:
return False

if response.status in self._retry_options.statuses:
Expand Down Expand Up @@ -113,7 +114,7 @@ async def _do_request(self) -> ClientResponse:
params.url,
headers=params.headers,
trace_request_ctx={
'current_attempt': current_attempt,
"current_attempt": current_attempt,
**(params.trace_request_ctx or {}),
},
**(params.kwargs or {}),
Expand All @@ -127,18 +128,17 @@ async def _do_request(self) -> ClientResponse:
response.raise_for_status()
self._response = response
return self._response
else:
retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=response)
retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=response)

except Exception as e:
if current_attempt >= self._retry_options.attempts:
raise e
raise

is_exc_valid = any([isinstance(e, exc) for exc in self._retry_options.exceptions])
is_exc_valid = any(isinstance(e, exc) for exc in self._retry_options.exceptions)
if not is_exc_valid:
raise e
raise

debug_message = f"Retrying after exception: {repr(e)}"
debug_message = f"Retrying after exception: {e!r}"
retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=None)

self._logger.debug(debug_message)
Expand All @@ -152,38 +152,39 @@ async def __aenter__(self) -> ClientResponse:

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self._response is not None:
if not self._response.closed:
self._response.close()
if self._response is not None and not self._response.closed:
self._response.close()


def _url_to_urls(url: _URL_TYPE) -> Tuple[StrOrURL, ...]:
if isinstance(url, str) or isinstance(url, YARL_URL):
def _url_to_urls(url: _URL_TYPE) -> tuple[StrOrURL, ...]:
if isinstance(url, (str, YARL_URL)):
return (url,)

if isinstance(url, list):
urls = tuple(url)
elif isinstance(url, tuple):
urls = url
else:
raise ValueError("you can pass url only by str or list/tuple")
msg = "you can pass url only by str or list/tuple"
raise ValueError(msg) # noqa: TRY004

if len(urls) == 0:
raise ValueError("you can pass url by str or list/tuple with attempts count size")
msg = "you can pass url by str or list/tuple with attempts count size"
raise ValueError(msg)

return urls


class RetryClient:
def __init__(
self,
client_session: Optional[ClientSession] = None,
logger: Optional[_LoggerType] = None,
retry_options: Optional[RetryOptionsBase] = None,
client_session: ClientSession | None = None,
logger: _LoggerType | None = None,
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool = False,
*args: Any,
**kwargs: Any,
Expand All @@ -208,9 +209,9 @@ def retry_options(self) -> RetryOptionsBase:

def requests(
self,
params_list: List[RequestParams],
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None,
params_list: list[RequestParams],
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
) -> _RequestContext:
return self._make_requests(
params_list=params_list,
Expand All @@ -222,8 +223,8 @@ def request(
self,
method: str,
url: StrOrURL,
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None,
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
**kwargs: Any,
) -> _RequestContext:
return self._make_request(
Expand All @@ -237,8 +238,8 @@ def request(
def get(
self,
url: _URL_TYPE,
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None,
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
**kwargs: Any,
) -> _RequestContext:
return self._make_request(
Expand All @@ -252,8 +253,8 @@ def get(
def options(
self,
url: _URL_TYPE,
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None,
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
**kwargs: Any,
) -> _RequestContext:
return self._make_request(
Expand All @@ -267,8 +268,9 @@ def options(
def head(
self,
url: _URL_TYPE,
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None, **kwargs: Any,
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
**kwargs: Any,
) -> _RequestContext:
return self._make_request(
method=hdrs.METH_HEAD,
Expand All @@ -281,8 +283,8 @@ def head(
def post(
self,
url: _URL_TYPE,
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None,
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
**kwargs: Any,
) -> _RequestContext:
return self._make_request(
Expand All @@ -296,8 +298,8 @@ def post(
def put(
self,
url: _URL_TYPE,
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None,
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
**kwargs: Any,
) -> _RequestContext:
return self._make_request(
Expand All @@ -311,8 +313,8 @@ def put(
def patch(
self,
url: _URL_TYPE,
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None,
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
**kwargs: Any,
) -> _RequestContext:
return self._make_request(
Expand All @@ -326,8 +328,8 @@ def patch(
def delete(
self,
url: _URL_TYPE,
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None,
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
**kwargs: Any,
) -> _RequestContext:
return self._make_request(
Expand All @@ -346,18 +348,21 @@ def _make_request(
self,
method: str,
url: _URL_TYPE,
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None,
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
**kwargs: Any,
) -> _RequestContext:
url_list = _url_to_urls(url)
params_list = [RequestParams(
method=method,
url=url,
headers=kwargs.pop('headers', {}),
trace_request_ctx=kwargs.pop('trace_request_ctx', None),
kwargs=kwargs,
) for url in url_list]
params_list = [
RequestParams(
method=method,
url=url,
headers=kwargs.pop("headers", {}),
trace_request_ctx=kwargs.pop("trace_request_ctx", None),
kwargs=kwargs,
)
for url in url_list
]

return self._make_requests(
params_list=params_list,
Expand All @@ -367,9 +372,9 @@ def _make_request(

def _make_requests(
self,
params_list: List[RequestParams],
retry_options: Optional[RetryOptionsBase] = None,
raise_for_status: Optional[bool] = None,
params_list: list[RequestParams],
retry_options: RetryOptionsBase | None = None,
raise_for_status: bool | None = None,
) -> _RequestContext:
if retry_options is None:
retry_options = self._retry_options
Expand All @@ -383,19 +388,19 @@ def _make_requests(
raise_for_status=raise_for_status,
)

async def __aenter__(self) -> 'RetryClient':
async def __aenter__(self) -> RetryClient: # noqa: PYI034
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()

def __del__(self) -> None:
if getattr(self, '_closed', None) is None:
if getattr(self, "_closed", None) is None:
# in case object was not initialized (__init__ raised an exception)
return

Expand Down
Loading

0 comments on commit c0cb00b

Please sign in to comment.