Skip to content

Commit

Permalink
Add pre/post refresh callbacks
Browse files Browse the repository at this point in the history
(cherry picked from commit bb0e002)
  • Loading branch information
bboe authored and LilSpazJoekp committed Feb 24, 2021
1 parent 3677c7c commit 496bc2c
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 73 deletions.
17 changes: 17 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,23 @@ Change Log
asyncprawcore follows `semantic versioning <http://semver.org/>`_ with the exception
that deprecations will not be announced by a minor release.

Unreleased
----------

**Added**

* ``Authorizer`` optionally takes a ``pre_refresh_callback`` keyword
argument. If provided, the function will called with the instance of
``Authorizer`` prior to refreshing the access and refresh tokens.
* ``Authorizer`` optionally takes a ``post_refresh_callback`` keyword
argument. If provided, the function will called with the instance of
``Authorizer`` after refreshing the access and refresh tokens.

**Changed**

* The ``refresh_token`` argument to ``Authorizer`` must now be passed by
keyword, and cannot be passed as a positional argument.

1.5.1 (2021-01-25)
------------------

Expand Down
30 changes: 29 additions & 1 deletion asyncprawcore/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Provides Authentication and Authorization classes."""
import aiohttp
import inspect
import time

from aiohttp import ClientRequest
Expand Down Expand Up @@ -194,16 +195,35 @@ class Authorizer(BaseAuthorizer):

AUTHENTICATOR_CLASS = BaseAuthenticator

def __init__(self, authenticator, refresh_token=None):
def __init__(
self,
authenticator,
*,
post_refresh_callback=None,
pre_refresh_callback=None,
refresh_token=None,
):
"""Represent a single authorization to Reddit's API.
:param authenticator: An instance of a subclass of
:class:`BaseAuthenticator`.
:param post_refresh_callback: (Optional) When a single-argument synchronous or
asynchronous function is passed, the function will be called prior to
refreshing the access and refresh tokens. The argument to the callback is
the :class:`Authorizer` instance. This callback can be used to inspect and
modify the attributes of the :class:`Authorizer`.
:param pre_refresh_callback: (Optional) When a single-argument function
synchronous or asynchronous is passed, the function will be called after
refreshing the access and refresh tokens. The argument to the callback is
the :class:`Authorizer` instance. This callback can be used to inspect and
modify the attributes of the :class:`Authorizer`.
:param refresh_token: (Optional) Enables the ability to refresh the
authorization.
"""
super(Authorizer, self).__init__(authenticator)
self._post_refresh_callback = post_refresh_callback
self._pre_refresh_callback = pre_refresh_callback
self.refresh_token = refresh_token

async def authorize(self, code):
Expand All @@ -223,11 +243,19 @@ async def authorize(self, code):

async def refresh(self):
"""Obtain a new access token from the refresh_token."""
if self._pre_refresh_callback:
result = self._pre_refresh_callback(self)
if inspect.isawaitable(result):
await self._pre_refresh_callback(self)
if self.refresh_token is None:
raise InvalidInvocation("refresh token not provided")
await self._request_token(
grant_type="refresh_token", refresh_token=self.refresh_token
)
if self._post_refresh_callback:
result = self._post_refresh_callback(self)
if inspect.isawaitable(result):
await self._post_refresh_callback(self)

async def revoke(self, only_access=False):
"""Revoke the current Authorization.
Expand Down
121 changes: 58 additions & 63 deletions tests/cassettes/Authorizer_refresh.json
Original file line number Diff line number Diff line change
@@ -1,64 +1,59 @@
{
"version": 1,
"interactions": [
{
"request": {
"method": "POST",
"uri": "https://www.reddit.com/api/v1/access_token",
"body": [
[
"grant_type",
"refresh_token"
],
[
"refresh_token",
"fake_refresh_token"
]
],
"headers": {
"User-Agent": [
"asyncprawcore:test (by /u/bboe) asyncprawcore/1.4.0"
],
"Connection": [
"close"
],
"AUTHORIZATION": [
"Basic <placeholder_auth>"
]
}
},
"response": {
"status": {
"code": 200,
"message": "OK"
},
"headers": {
"Connection": "close",
"Content-Length": "40",
"Content-Type": "application/json; charset=UTF-8",
"x-frame-options": "SAMEORIGIN",
"x-content-type-options": "nosniff",
"x-xss-protection": "1; mode=block",
"x-ua-compatible": "IE=edge",
"set-cookie": "session_tracker=GTiELnoAfGihqJgp4R.0.1592045156266.Z0FBQUFBQmU1SzVra09KQmZhamhzMUh1aXdZMFNsTWJCTHJBdFR1dE9QLUdZTmpvQ1EzVXlXNmdSR3ZUQXh2dFd1MTh5Mzc2UXFleTZZV2dxWWREY0FLSU8xMWotd1JRWlpHNjA5TDl3b3VEQjNlUm5JbXVFVUdNaHVnZU5NNUNqR0o0QzFZNzRoMm0; Domain=reddit.com; Max-Age=7199; Path=/; expires=Sat, 13-Jun-2020 12:45:56 GMT; secure",
"cache-control": "max-age=0, must-revalidate",
"X-Moose": "majestic",
"Accept-Ranges": "bytes",
"Date": "Sat, 13 Jun 2020 10:45:56 GMT",
"Via": "1.1 varnish",
"X-Served-By": "cache-ams21024-AMS",
"X-Cache": "MISS",
"X-Cache-Hits": "0",
"X-Timer": "S1592045156.171554,VS0,VE141",
"Set-Cookie": "edgebucket=jg5MWvCfB9iCg5xoTg; Domain=reddit.com; Max-Age=63071999; Path=/; secure",
"Strict-Transport-Security": "max-age=15552000; includeSubDomains; preload",
"Server": "snooserv"
},
"body": {
"string": "{\"access_token\": \"fake_access_Token\", \"token_type\": \"bearer\", \"expires_in\": 3600, \"scope\": \"submit\"}"
},
"url": "https://www.reddit.com/api/v1/access_token"
}
}
]
}
"interactions": [
{
"request": {
"body": [
[
"grant_type",
"refresh_token"
],
[
"refresh_token",
"<REFRESH_TOKEN>"
]
],
"headers": {
"AUTHORIZATION": [
"Basic <BASIC_AUTH>"
],
"Connection": [
"close"
],
"User-Agent": [
"asyncprawcore:test (by /u/bboe) asyncprawcore/1.5.1"
]
},
"method": "POST",
"uri": "https://www.reddit.com/api/v1/access_token"
},
"response": {
"body": {
"string": "{\"access_token\": \"<ACCESS_TOKEN>\", \"token_type\": \"bearer\", \"expires_in\": 3600, \"refresh_token\": \"aaaaaaa-0000000000000000000000-aaaaaaa\", \"scope\": \"account creddits edit flair history identity livemanage modconfig modcontributors modflair modlog modmail modothers modposts modself modtraffic modwiki mysubreddits privatemessages read report save structuredstyles submit subscribe vote wikiedit wikiread\"}"
},
"headers": {
"Accept-Ranges": "bytes",
"Cache-Control": "max-age=0, must-revalidate",
"Connection": "close",
"Content-Length": "430",
"Content-Type": "application/json; charset=UTF-8",
"Date": "Wed, 24 Feb 2021 04:05:45 GMT",
"Server": "snooserv",
"Set-Cookie": "edgebucket=RJRxvqXbGD1qATlfHc; Domain=reddit.com; Max-Age=63071999; Path=/; secure",
"Strict-Transport-Security": "max-age=15552000; includeSubDomains; preload",
"Via": "1.1 varnish",
"X-Moose": "majestic",
"x-content-type-options": "nosniff",
"x-frame-options": "SAMEORIGIN",
"x-xss-protection": "1; mode=block"
},
"status": {
"code": 200,
"message": "OK"
},
"url": "https://www.reddit.com/api/v1/access_token"
}
}
],
"recorded_at": "2021-02-23T22:05:45",
"version": 1
}
98 changes: 91 additions & 7 deletions tests/test_authorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def test_initialize(self):
self.assertFalse(authorizer.is_valid())

def test_initialize__with_refresh_token(self):
authorizer = asyncprawcore.Authorizer(self.authentication, REFRESH_TOKEN)
authorizer = asyncprawcore.Authorizer(
self.authentication, refresh_token=REFRESH_TOKEN
)
self.assertIsNone(authorizer.access_token)
self.assertIsNone(authorizer.scopes)
self.assertEqual(REFRESH_TOKEN, authorizer.refresh_token)
Expand All @@ -88,7 +90,79 @@ def test_initialize__with_untrusted_authenticator(self):
self.assertFalse(authorizer.is_valid())

async def test_refresh(self):
authorizer = asyncprawcore.Authorizer(self.authentication, REFRESH_TOKEN)
authorizer = asyncprawcore.Authorizer(
self.authentication, refresh_token=REFRESH_TOKEN
)
with VCR.use_cassette("Authorizer_refresh"):
await authorizer.refresh()

self.assertIsNotNone(authorizer.access_token)
self.assertIsInstance(authorizer.scopes, set)
self.assertTrue(len(authorizer.scopes) > 0)
self.assertTrue(authorizer.is_valid())

async def test_refresh__post_refresh_callback(self):
def callback(authorizer):
self.assertNotEqual(REFRESH_TOKEN, authorizer.refresh_token)
authorizer.refresh_token = "manually_updated"

authorizer = asyncprawcore.Authorizer(
self.authentication,
post_refresh_callback=callback,
refresh_token=REFRESH_TOKEN,
)
with VCR.use_cassette("Authorizer_refresh"):
await authorizer.refresh()

self.assertIsNotNone(authorizer.access_token)
self.assertEqual("manually_updated", authorizer.refresh_token)
self.assertIsInstance(authorizer.scopes, set)
self.assertTrue(len(authorizer.scopes) > 0)
self.assertTrue(authorizer.is_valid())

async def test_refresh__post_refresh_callback__async(self):
async def callback(authorizer):
self.assertNotEqual(REFRESH_TOKEN, authorizer.refresh_token)
authorizer.refresh_token = "manually_updated"

authorizer = asyncprawcore.Authorizer(
self.authentication,
post_refresh_callback=callback,
refresh_token=REFRESH_TOKEN,
)
with VCR.use_cassette("Authorizer_refresh"):
await authorizer.refresh()

self.assertIsNotNone(authorizer.access_token)
self.assertEqual("manually_updated", authorizer.refresh_token)
self.assertIsInstance(authorizer.scopes, set)
self.assertTrue(len(authorizer.scopes) > 0)
self.assertTrue(authorizer.is_valid())

async def test_refresh__pre_refresh_callback(self):
def callback(authorizer):
self.assertIsNone(authorizer.refresh_token)
authorizer.refresh_token = REFRESH_TOKEN

authorizer = asyncprawcore.Authorizer(
self.authentication, pre_refresh_callback=callback
)
with VCR.use_cassette("Authorizer_refresh"):
await authorizer.refresh()

self.assertIsNotNone(authorizer.access_token)
self.assertIsInstance(authorizer.scopes, set)
self.assertTrue(len(authorizer.scopes) > 0)
self.assertTrue(authorizer.is_valid())

async def test_refresh__pre_refresh_callback__async(self):
async def callback(authorizer):
self.assertIsNone(authorizer.refresh_token)
authorizer.refresh_token = REFRESH_TOKEN

authorizer = asyncprawcore.Authorizer(
self.authentication, pre_refresh_callback=callback
)
with VCR.use_cassette("Authorizer_refresh"):
await authorizer.refresh()

Expand All @@ -98,7 +172,9 @@ async def test_refresh(self):
self.assertTrue(authorizer.is_valid())

async def test_refresh__with_invalid_token(self):
authorizer = asyncprawcore.Authorizer(self.authentication, "INVALID_TOKEN")
authorizer = asyncprawcore.Authorizer(
self.authentication, refresh_token="INVALID_TOKEN"
)
with VCR.use_cassette("Authorizer_refresh__with_invalid_token"):
with self.assertRaises(asyncprawcore.ResponseException):
await authorizer.refresh()
Expand All @@ -111,7 +187,9 @@ async def test_refresh__without_refresh_token(self):
self.assertFalse(authorizer.is_valid())

async def test_revoke__access_token_with_refresh_set(self):
authorizer = asyncprawcore.Authorizer(self.authentication, REFRESH_TOKEN)
authorizer = asyncprawcore.Authorizer(
self.authentication, refresh_token=REFRESH_TOKEN
)
with VCR.use_cassette("Authorizer_revoke__access_token_with_refresh_set"):
await authorizer.refresh()
await authorizer.revoke(only_access=True)
Expand All @@ -138,7 +216,9 @@ async def test_revoke__access_token_without_refresh_set(self):
self.assertFalse(authorizer.is_valid())

async def test_revoke__refresh_token_with_access_set(self):
authorizer = asyncprawcore.Authorizer(self.authentication, REFRESH_TOKEN)
authorizer = asyncprawcore.Authorizer(
self.authentication, refresh_token=REFRESH_TOKEN
)
with VCR.use_cassette("Authorizer_revoke__refresh_token_with_access_set"):
await authorizer.refresh()
await authorizer.revoke()
Expand All @@ -149,7 +229,9 @@ async def test_revoke__refresh_token_with_access_set(self):
self.assertFalse(authorizer.is_valid())

async def test_revoke__refresh_token_without_access_set(self):
authorizer = asyncprawcore.Authorizer(self.authentication, REFRESH_TOKEN)
authorizer = asyncprawcore.Authorizer(
self.authentication, refresh_token=REFRESH_TOKEN
)
with VCR.use_cassette("Authorizer_revoke__refresh_token_without_access_set"):
await authorizer.revoke()

Expand All @@ -159,7 +241,9 @@ async def test_revoke__refresh_token_without_access_set(self):
self.assertFalse(authorizer.is_valid())

async def test_revoke__without_access_token(self):
authorizer = asyncprawcore.Authorizer(self.authentication, REFRESH_TOKEN)
authorizer = asyncprawcore.Authorizer(
self.authentication, refresh_token=REFRESH_TOKEN
)
with self.assertRaises(asyncprawcore.InvalidInvocation):
await authorizer.revoke(only_access=True)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def client_authorizer():
authenticator = asyncprawcore.TrustedAuthenticator(
requestor, CLIENT_ID, CLIENT_SECRET
)
authorizer = asyncprawcore.Authorizer(authenticator, REFRESH_TOKEN)
authorizer = asyncprawcore.Authorizer(authenticator, refresh_token=REFRESH_TOKEN)
await authorizer.refresh()
return authorizer

Expand Down Expand Up @@ -401,7 +401,7 @@ async def test_request__unsupported_media_type(self):
await session.request("POST", "r/asyncpraw/api/wiki/edit/", data=data)
self.assertEqual(415, context_manager.exception.response.status)

async def test_request__with_insufficent_scope(self):
async def test_request__with_insufficient_scope(self):
with VCR.use_cassette("Session_request__with_insufficient_scope"):
self.session = asyncprawcore.Session(await client_authorizer())
with self.assertRaises(asyncprawcore.InsufficientScope):
Expand Down

0 comments on commit 496bc2c

Please sign in to comment.