Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: some additional logs #216

Merged
merged 17 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 15 additions & 19 deletions pyeudiw/federation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,33 @@
from .exceptions import InvalidEntityStatement, InvalidEntityConfiguration
from pyeudiw.federation.schemas.entity_configuration import EntityStatementPayload, EntityConfigurationPayload

def is_es(payload: dict) -> bool:
def is_es(payload: dict) -> None:
"""
Determines if payload dict is an Entity Statement
Determines if payload dict is a Subordinate Entity Statement

:param payload: the object to determine if is an Entity Statement
:param payload: the object to determine if is a Subordinate Entity Statement
:type payload: dict

:returns: True if is an Entity Statement and False otherwise
:rtype: bool
"""

try:
EntityStatementPayload(**payload)
if payload["iss"] != payload["sub"]:
return True
except Exception:
return False


def is_ec(payload: dict) -> bool:
if payload["iss"] == payload["sub"]:
_msg = f"Invalid Entity Statement: iss and sub cannot be the same"
raise InvalidEntityStatement(_msg)
except ValueError as e:
_msg = f"Invalid Entity Statement: {e}"
raise InvalidEntityStatement(_msg)

def is_ec(payload: dict) -> None:
"""
Determines if payload dict is an Entity Configuration

:param payload: the object to determine if is an Entity Configuration
:type payload: dict

:returns: True if is an Entity Configuration and False otherwise
:rtype: bool
"""

try:
EntityConfigurationPayload(**payload)
return True
except Exception as e:
return False
except ValueError as e:
_msg = f"Invalid Entity Configuration: {e}"
raise InvalidEntityConfiguration(_msg)
74 changes: 55 additions & 19 deletions pyeudiw/federation/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import asyncio
import requests

from .exceptions import HttpError

async def fetch(session: dict, url: str, httpc_params: dict) -> str:

async def fetch(session: aiohttp.ClientSession, url: str, httpc_params: dict) -> requests.Response:
"""
Fetches the content of a URL.

Expand All @@ -20,12 +22,11 @@ async def fetch(session: dict, url: str, httpc_params: dict) -> str:

async with session.get(url, **httpc_params.get("connection", {})) as response:
if response.status != 200: # pragma: no cover
# response.raise_for_status()
return ""
return await response.text()
response.raise_for_status()
return await response


async def fetch_all(session: dict, urls: list[str], httpc_params: dict) -> list[str]:
async def fetch_all(session: aiohttp.ClientSession, urls: list[str], httpc_params: dict) -> list[requests.Response]:
"""
Fetches the content of a list of URL.

Expand All @@ -36,6 +37,8 @@ async def fetch_all(session: dict, urls: list[str], httpc_params: dict) -> list[
:param httpc_params: parameters to perform http requests.
:type httpc_params: dict

:raises HttpError: if the response status code is not 200 or a connection error occurs

:returns: the list of responses in string format
:rtype: list[str]
"""
Expand All @@ -44,13 +47,21 @@ async def fetch_all(session: dict, urls: list[str], httpc_params: dict) -> list[
for url in urls:
task = asyncio.create_task(fetch(session, url, httpc_params))
tasks.append(task)
results = await asyncio.gather(*tasks)
return results

try:
results: list[requests.Response] = await asyncio.gather(*tasks)
except aiohttp.ClientConnectorError as e:
raise HttpError(f"Connection error: {e}")

for r in results:
if r.status_code != 200:
raise HttpError(f"HTTP error: {r.status_code} -- {r.reason}")

return results

async def http_get(urls, httpc_params: dict, sync=True):
def http_get_sync(urls, httpc_params: dict) -> list[requests.Response]:
"""
Perform a GET http call.
Perform a GET http call sync.

:param session: a dict representing the current session
:type session: dict
Expand All @@ -59,20 +70,45 @@ async def http_get(urls, httpc_params: dict, sync=True):
:param httpc_params: parameters to perform http requests.
:type httpc_params: dict

:returns: the list of responses in string format
:rtype: list[str]
:raises HttpError: if the response status code is not 200 or a connection error occurs

:returns: the list of responses
:rtype: list[requests.Response]
"""
if sync:
_conf = {
'verify': httpc_params['connection']['ssl'],
'timeout': httpc_params['session']['timeout']
}
_conf = {
'verify': httpc_params['connection']['ssl'],
'timeout': httpc_params['session']['timeout']
}
try:
res = [
requests.get(url, **_conf).content # nosec - B113
requests.get(url, **_conf) # nosec - B113
for url in urls
]
return res
except requests.exceptions.ConnectionError as e:
raise HttpError(f"Connection error: {e}")

for r in res:
if r.status_code != 200:
raise HttpError(f"HTTP error: {r.status_code} -- {r.reason}")

return res

async def http_get_async(urls, httpc_params: dict) -> list[requests.Response]:
"""
Perform a GET http call async.

:param session: a dict representing the current session
:type session: dict
:param urls: the url list where fetch the content
:type urls: list[str]
:param httpc_params: parameters to perform http requests.
:type httpc_params: dict

:raises HttpError: if the response status code is not 200 or a connection error occurs

:returns: the list of responses
:rtype: list[requests.Response]
"""
if not isinstance(httpc_params['session']['timeout'], aiohttp.ClientTimeout):
httpc_params['session']['timeout'] = aiohttp.ClientTimeout(
total=httpc_params['session']['timeout']
Expand All @@ -95,4 +131,4 @@ async def http_get(urls, httpc_params: dict, sync=True):
"http://127.0.0.1:8001/.well-known/openid-federation",
"http://google.it",
]
asyncio.run(http_get(urls, httpc_params=httpc_params))
asyncio.run(http_get_async(urls, httpc_params=httpc_params))
52 changes: 26 additions & 26 deletions pyeudiw/federation/statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,14 @@
EntityConfigurationHeader,
EntityStatementPayload
)
from pydantic import ValidationError
from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header
from pyeudiw.jwt import JWSHelper
from pyeudiw.tools.utils import get_http_url
from pydantic import ValidationError

from pyeudiw.jwk import find_jwk
from pyeudiw.tools.utils import get_http_url

import json
import logging

try:
pass
except ImportError: # pragma: no cover
pass


OIDCFED_FEDERATION_WELLKNOWN_URL = ".well-known/openid-federation"
logger = logging.getLogger(__name__)

Expand All @@ -49,9 +41,9 @@ def jwks_from_jwks_uri(jwks_uri: str, httpc_params: dict, http_async: bool = Tru
"""

response = get_http_url(jwks_uri, httpc_params, http_async)
jwks = json.loads(response)
jwks = [i.json() for i in response]

return [jwks]
return jwks


def get_federation_jwks(jwt_payload: dict) -> list[dict]:
Expand All @@ -67,11 +59,10 @@ def get_federation_jwks(jwt_payload: dict) -> list[dict]:

jwks = jwt_payload.get("jwks", {})
keys = jwks.get("keys", [])

return keys


def get_entity_statements(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[dict]:
def get_entity_statements(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[bytes]:
"""
Fetches an entity statement from the specified urls.

Expand All @@ -83,18 +74,20 @@ def get_entity_statements(urls: list[str] | str, httpc_params: dict, http_async:
:type http_async: bool

:returns: A list of entity statements.
:rtype: list[dict]
:rtype: list[Response]
"""

urls = urls if isinstance(urls, list) else [urls]

for url in urls:
logger.debug(f"Starting Entity Statement Request to {url}")

return get_http_url(urls, httpc_params, http_async)
return [
i.content for i in
get_http_url(urls, httpc_params, http_async)
]


def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, http_async: bool = True):
def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, http_async: bool = False) -> list[bytes]:
"""
Fetches an entity configuration from the specified subjects.

Expand All @@ -106,7 +99,7 @@ def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, htt
:type http_async: bool

:returns: A list of entity statements.
:rtype: list[dict]
:rtype: list[Response]
"""

subjects = subjects if isinstance(subjects, list) else [subjects]
Expand All @@ -119,7 +112,10 @@ def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, htt
urls.append(url)
logger.info(f"Starting Entity Configuration Request for {url}")

return get_http_url(urls, httpc_params, http_async)
return [
i.content for i in
get_http_url(urls, httpc_params, http_async)
]


class TrustMark:
Expand All @@ -145,7 +141,7 @@ def __init__(self, jwt: str, httpc_params: dict):

self.is_valid = False

self.issuer_entity_configuration = None
self.issuer_entity_configuration: list[bytes] = None
self.httpc_params = httpc_params

def validate_by(self, ec: dict) -> bool:
Expand Down Expand Up @@ -191,9 +187,12 @@ def validate_by_its_issuer(self) -> bool:
:rtype: bool
"""
if not self.issuer_entity_configuration:
self.issuer_entity_configuration = get_entity_configurations(
self.iss, self.httpc_params, False
)
self.issuer_entity_configuration = [
i.content for i in
get_entity_configurations(
self.iss, self.httpc_params, False
)
]

_kid = self.header.get('kid')
try:
Expand Down Expand Up @@ -232,7 +231,7 @@ def __init__(
jwt: str,
httpc_params: dict,
filter_by_allowed_trust_marks: list[str] = [],
trust_anchor_entity_conf: 'EntityStatement' | None = None,
trust_anchor_entity_conf: EntityStatement | None = None,
trust_mark_issuers_entity_confs: list[EntityStatement] = [],
):
"""
Expand Down Expand Up @@ -474,7 +473,8 @@ def get_superiors(

if not jwts:
jwts = get_entity_configurations(
authority_hints, self.httpc_params, False)
authority_hints, self.httpc_params, False
)

for jwt in jwts:
try:
Expand Down
Loading
Loading