diff --git a/pyeudiw/federation/__init__.py b/pyeudiw/federation/__init__.py index afdc418d..f6a699bf 100644 --- a/pyeudiw/federation/__init__.py +++ b/pyeudiw/federation/__init__.py @@ -1,37 +1,33 @@ +from .exceptions import InvalidEntityStatement, InvalidEntityConfiguration from pyeudiw.federation.schemas.entity_configuration import EntityStatementPayload, EntityConfigurationPayload -def is_es(payload: dict) -> bool: +def is_es(payload: dict) -> None: """ - Determines if payload dict is an Entity Statement + Determines if payload dict is a Subordinate Entity Statement - :param payload: the object to determine if is an Entity Statement + :param payload: the object to determine if is a Subordinate Entity Statement :type payload: dict - - :returns: True if is an Entity Statement and False otherwise - :rtype: bool """ try: EntityStatementPayload(**payload) - if payload["iss"] != payload["sub"]: - return True - except Exception: - return False - - -def is_ec(payload: dict) -> bool: + if payload["iss"] == payload["sub"]: + _msg = f"Invalid Entity Statement: iss and sub cannot be the same" + raise InvalidEntityStatement(_msg) + except ValueError as e: + _msg = f"Invalid Entity Statement: {e}" + raise InvalidEntityStatement(_msg) + +def is_ec(payload: dict) -> None: """ Determines if payload dict is an Entity Configuration :param payload: the object to determine if is an Entity Configuration :type payload: dict - - :returns: True if is an Entity Configuration and False otherwise - :rtype: bool """ try: EntityConfigurationPayload(**payload) - return True - except Exception as e: - return False \ No newline at end of file + except ValueError as e: + _msg = f"Invalid Entity Configuration: {e}" + raise InvalidEntityConfiguration(_msg) diff --git a/pyeudiw/federation/http_client.py b/pyeudiw/federation/http_client.py index a7cb3b04..2030c809 100644 --- a/pyeudiw/federation/http_client.py +++ b/pyeudiw/federation/http_client.py @@ -2,8 +2,10 @@ import asyncio import requests +from .exceptions import HttpError -async def fetch(session: dict, url: str, httpc_params: dict) -> str: + +async def fetch(session: aiohttp.ClientSession, url: str, httpc_params: dict) -> requests.Response: """ Fetches the content of a URL. @@ -20,12 +22,11 @@ async def fetch(session: dict, url: str, httpc_params: dict) -> str: async with session.get(url, **httpc_params.get("connection", {})) as response: if response.status != 200: # pragma: no cover - # response.raise_for_status() - return "" - return await response.text() + response.raise_for_status() + return await response -async def fetch_all(session: dict, urls: list[str], httpc_params: dict) -> list[str]: +async def fetch_all(session: aiohttp.ClientSession, urls: list[str], httpc_params: dict) -> list[requests.Response]: """ Fetches the content of a list of URL. @@ -36,6 +37,8 @@ async def fetch_all(session: dict, urls: list[str], httpc_params: dict) -> list[ :param httpc_params: parameters to perform http requests. :type httpc_params: dict + :raises HttpError: if the response status code is not 200 or a connection error occurs + :returns: the list of responses in string format :rtype: list[str] """ @@ -44,13 +47,21 @@ async def fetch_all(session: dict, urls: list[str], httpc_params: dict) -> list[ for url in urls: task = asyncio.create_task(fetch(session, url, httpc_params)) tasks.append(task) - results = await asyncio.gather(*tasks) - return results + try: + results: list[requests.Response] = await asyncio.gather(*tasks) + except aiohttp.ClientConnectorError as e: + raise HttpError(f"Connection error: {e}") + + for r in results: + if r.status_code != 200: + raise HttpError(f"HTTP error: {r.status_code} -- {r.reason}") + + return results -async def http_get(urls, httpc_params: dict, sync=True): +def http_get_sync(urls, httpc_params: dict) -> list[requests.Response]: """ - Perform a GET http call. + Perform a GET http call sync. :param session: a dict representing the current session :type session: dict @@ -59,20 +70,45 @@ async def http_get(urls, httpc_params: dict, sync=True): :param httpc_params: parameters to perform http requests. :type httpc_params: dict - :returns: the list of responses in string format - :rtype: list[str] + :raises HttpError: if the response status code is not 200 or a connection error occurs + + :returns: the list of responses + :rtype: list[requests.Response] """ - if sync: - _conf = { - 'verify': httpc_params['connection']['ssl'], - 'timeout': httpc_params['session']['timeout'] - } + _conf = { + 'verify': httpc_params['connection']['ssl'], + 'timeout': httpc_params['session']['timeout'] + } + try: res = [ - requests.get(url, **_conf).content # nosec - B113 + requests.get(url, **_conf) # nosec - B113 for url in urls ] - return res + except requests.exceptions.ConnectionError as e: + raise HttpError(f"Connection error: {e}") + + for r in res: + if r.status_code != 200: + raise HttpError(f"HTTP error: {r.status_code} -- {r.reason}") + + return res +async def http_get_async(urls, httpc_params: dict) -> list[requests.Response]: + """ + Perform a GET http call async. + + :param session: a dict representing the current session + :type session: dict + :param urls: the url list where fetch the content + :type urls: list[str] + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + + :raises HttpError: if the response status code is not 200 or a connection error occurs + + :returns: the list of responses + :rtype: list[requests.Response] + """ if not isinstance(httpc_params['session']['timeout'], aiohttp.ClientTimeout): httpc_params['session']['timeout'] = aiohttp.ClientTimeout( total=httpc_params['session']['timeout'] @@ -95,4 +131,4 @@ async def http_get(urls, httpc_params: dict, sync=True): "http://127.0.0.1:8001/.well-known/openid-federation", "http://google.it", ] - asyncio.run(http_get(urls, httpc_params=httpc_params)) + asyncio.run(http_get_async(urls, httpc_params=httpc_params)) diff --git a/pyeudiw/federation/statements.py b/pyeudiw/federation/statements.py index b74012fa..73e30667 100644 --- a/pyeudiw/federation/statements.py +++ b/pyeudiw/federation/statements.py @@ -13,22 +13,14 @@ EntityConfigurationHeader, EntityStatementPayload ) +from pydantic import ValidationError from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header from pyeudiw.jwt import JWSHelper -from pyeudiw.tools.utils import get_http_url -from pydantic import ValidationError - from pyeudiw.jwk import find_jwk +from pyeudiw.tools.utils import get_http_url -import json import logging -try: - pass -except ImportError: # pragma: no cover - pass - - OIDCFED_FEDERATION_WELLKNOWN_URL = ".well-known/openid-federation" logger = logging.getLogger(__name__) @@ -49,9 +41,9 @@ def jwks_from_jwks_uri(jwks_uri: str, httpc_params: dict, http_async: bool = Tru """ response = get_http_url(jwks_uri, httpc_params, http_async) - jwks = json.loads(response) + jwks = [i.json() for i in response] - return [jwks] + return jwks def get_federation_jwks(jwt_payload: dict) -> list[dict]: @@ -67,11 +59,10 @@ def get_federation_jwks(jwt_payload: dict) -> list[dict]: jwks = jwt_payload.get("jwks", {}) keys = jwks.get("keys", []) - return keys -def get_entity_statements(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[dict]: +def get_entity_statements(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[bytes]: """ Fetches an entity statement from the specified urls. @@ -83,18 +74,20 @@ def get_entity_statements(urls: list[str] | str, httpc_params: dict, http_async: :type http_async: bool :returns: A list of entity statements. - :rtype: list[dict] + :rtype: list[Response] """ urls = urls if isinstance(urls, list) else [urls] - for url in urls: logger.debug(f"Starting Entity Statement Request to {url}") - return get_http_url(urls, httpc_params, http_async) + return [ + i.content for i in + get_http_url(urls, httpc_params, http_async) + ] -def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, http_async: bool = True): +def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, http_async: bool = False) -> list[bytes]: """ Fetches an entity configuration from the specified subjects. @@ -106,7 +99,7 @@ def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, htt :type http_async: bool :returns: A list of entity statements. - :rtype: list[dict] + :rtype: list[Response] """ subjects = subjects if isinstance(subjects, list) else [subjects] @@ -119,7 +112,10 @@ def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, htt urls.append(url) logger.info(f"Starting Entity Configuration Request for {url}") - return get_http_url(urls, httpc_params, http_async) + return [ + i.content for i in + get_http_url(urls, httpc_params, http_async) + ] class TrustMark: @@ -145,7 +141,7 @@ def __init__(self, jwt: str, httpc_params: dict): self.is_valid = False - self.issuer_entity_configuration = None + self.issuer_entity_configuration: list[bytes] = None self.httpc_params = httpc_params def validate_by(self, ec: dict) -> bool: @@ -191,9 +187,12 @@ def validate_by_its_issuer(self) -> bool: :rtype: bool """ if not self.issuer_entity_configuration: - self.issuer_entity_configuration = get_entity_configurations( - self.iss, self.httpc_params, False - ) + self.issuer_entity_configuration = [ + i.content for i in + get_entity_configurations( + self.iss, self.httpc_params, False + ) + ] _kid = self.header.get('kid') try: @@ -232,7 +231,7 @@ def __init__( jwt: str, httpc_params: dict, filter_by_allowed_trust_marks: list[str] = [], - trust_anchor_entity_conf: 'EntityStatement' | None = None, + trust_anchor_entity_conf: EntityStatement | None = None, trust_mark_issuers_entity_confs: list[EntityStatement] = [], ): """ @@ -474,7 +473,8 @@ def get_superiors( if not jwts: jwts = get_entity_configurations( - authority_hints, self.httpc_params, False) + authority_hints, self.httpc_params, False + ) for jwt in jwts: try: diff --git a/pyeudiw/federation/trust_chain_validator.py b/pyeudiw/federation/trust_chain_validator.py index 6426d9b8..bafa6081 100644 --- a/pyeudiw/federation/trust_chain_validator.py +++ b/pyeudiw/federation/trust_chain_validator.py @@ -12,7 +12,8 @@ HttpError, MissingTrustAnchorPublicKey, TimeValidationError, - KeyValidationError + KeyValidationError, + InvalidEntityStatement ) from pyeudiw.jwk import find_jwk @@ -132,18 +133,27 @@ def validate(self) -> bool: ) if not ta_jwk: + logger.error( + f"Trust chain validation error: TA jwks not found." + ) return False # Validate the last statement with ta_jwk jwsh = JWSHelper(ta_jwk) if not jwsh.verify(last_element): + logger.error( + f"Trust chain signature validation error: {last_element} using {ta_jwk}" + ) return False # then go ahead with other checks self.exp = es_payload["exp"] if self._check_expired(self.exp): + logger.error( + f"Trust chain validation error, statement expired: {es_payload}" + ) return False fed_jwks = es_payload["jwks"]["keys"] @@ -160,10 +170,16 @@ def validate(self) -> bool: st_header.get("kid", None), fed_jwks ) except (KidNotFoundError, InvalidKid): + logger.error( + f"Trust chain validation KidNotFoundError: {st_header} not in {fed_jwks}" + ) return False jwsh = JWSHelper(jwk) if not jwsh.verify(st): + logger.error( + f"Trust chain signature validation error: {st} using {jwk}" + ) return False else: fed_jwks = st_payload["jwks"]["keys"] @@ -183,11 +199,6 @@ def _retrieve_ec(self, iss: str) -> str: :rtype: str """ jwt = get_entity_configurations(iss, self.httpc_params) - if not jwt: - raise HttpError( - f"Cannot get the Entity Configuration from {iss}") - - # is something weird these will raise their Exceptions return jwt[0] def _retrieve_es(self, download_url: str, iss: str) -> str: @@ -203,17 +214,11 @@ def _retrieve_es(self, download_url: str, iss: str) -> str: :rtype: str """ jwt = get_entity_statements(download_url, self.httpc_params) - if not jwt: - logger.warning( - f"Cannot fast refresh Entity Statement {iss}" - ) - if isinstance(jwt, list) and jwt: - return jwt[0] - return jwt + return jwt[0] def _update_st(self, st: str) -> str: """ - Updates the statement retrieving the new one using the source end_point and the sub fields of st payload. + Updates the statement retrieving the new one using the source_endpoint and the sub fields of the entity statement payload. :param st: The statement in form of a JWT. :type st: str @@ -223,8 +228,11 @@ def _update_st(self, st: str) -> str: """ payload = decode_jwt_payload(st) iss = payload['iss'] - if not is_es(payload): + + try: + is_es(payload) # It's an entity configuration + except InvalidEntityStatement: return self._retrieve_ec(iss) # if it has the source_endpoint let's try a fast renewal diff --git a/pyeudiw/jwt/utils.py b/pyeudiw/jwt/utils.py index cf0285ea..5c3cf2d2 100644 --- a/pyeudiw/jwt/utils.py +++ b/pyeudiw/jwt/utils.py @@ -131,7 +131,6 @@ def is_jws_format(jwt: str): :returns: True if the string is a JWS, False otherwise. :rtype: bool """ - breakpoint() if not is_jwt_format(jwt): return False diff --git a/pyeudiw/oauth2/dpop/__init__.py b/pyeudiw/oauth2/dpop/__init__.py index 986efd44..edfdffdf 100644 --- a/pyeudiw/oauth2/dpop/__init__.py +++ b/pyeudiw/oauth2/dpop/__init__.py @@ -101,6 +101,8 @@ def __init__( "Jwk validation error, " f"{e.__class__.__name__}: {e}" ) + raise ValueError("JWK schema validation error during DPoP init") + # If the jwt is invalid, this will raise an exception try: decode_jwt_header(http_header_dpop) diff --git a/pyeudiw/openid4vp/direct_post_response.py b/pyeudiw/openid4vp/direct_post_response.py index a51ad414..3007d12e 100644 --- a/pyeudiw/openid4vp/direct_post_response.py +++ b/pyeudiw/openid4vp/direct_post_response.py @@ -1,9 +1,12 @@ +import logging + from typing import Dict from pyeudiw.jwk import JWK from pyeudiw.jwt import JWEHelper, JWSHelper from pyeudiw.jwk.exceptions import KidNotFoundError from pyeudiw.jwt.utils import decode_jwt_header, is_jwe_format from pyeudiw.openid4vp.exceptions import ( + InvalidVPToken, VPNotFound, VPInvalidNonce, NoNonceInVPToken @@ -12,6 +15,9 @@ from pyeudiw.openid4vp.vp import Vp from pydantic import ValidationError +logger = logging.getLogger(__name__) + + class DirectPostResponse: """ Helper class for generate Direct Post Response. @@ -90,8 +96,10 @@ def _validate_vp(self, vp: dict) -> bool: ) VPTokenPayload(**vp.payload) VPTokenHeader(**vp.headers) - except ValidationError: - return False + except ValidationError as e: + raise InvalidVPToken( + f"VP is not valid, {e}: {vp.headers}.{vp.payload}" + ) return True @@ -102,12 +110,19 @@ def validate(self) -> bool: :returns: True if all VP are valid, False otherwhise. :rtype: bool """ - + all_valid = None for vp in self.get_presentation_vps(): - if not self._validate_vp(vp): - return False - - return True + try: + self._validate_vp(vp) + if all_valid == None: + all_valid = True + except Exception as e: + logger.error( + + ) + all_valid = False + + return all_valid def get_presentation_vps(self) -> list[Vp]: """ @@ -125,16 +140,18 @@ def get_presentation_vps(self) -> list[Vp]: vps = [_vps] if isinstance(_vps, str) else _vps if not vps: - raise VPNotFound(f"Vps are empty for response with nonce \"{self.nonce}\"") + raise VPNotFound( + f'Vps are empty for response with nonce "{self.nonce}"' + ) for vp in vps: + # TODO - add an exception handling here _vp = Vp(vp) self._vps.append(_vp) cred_iss = _vp.credential_payload['iss'] if not self.credentials_by_issuer.get(cred_iss, None): self.credentials_by_issuer[cred_iss] = [] - self.credentials_by_issuer[cred_iss].append(_vp.payload['vp']) return self._vps @@ -151,4 +168,4 @@ def payload(self) -> dict: """Returns the decoded payload of presentation""" if not self._payload: self._decode_payload() - return self._payload \ No newline at end of file + return self._payload diff --git a/pyeudiw/satosa/dpop.py b/pyeudiw/satosa/dpop.py index 778b0696..7fb4b419 100644 --- a/pyeudiw/satosa/dpop.py +++ b/pyeudiw/satosa/dpop.py @@ -12,6 +12,7 @@ from pyeudiw.tools.base_logger import BaseLogger from .base_http_error_handler import BaseHTTPErrorHandler + class BackendDPoP(BaseHTTPErrorHandler, BaseLogger): """ Backend DPoP class. @@ -43,16 +44,20 @@ def _request_endpoint_dpop(self, context: Context, *args) -> Union[JsonResponse, try: WalletInstanceAttestationHeader(**_head) except ValidationError as e: - self._log_warning(context, message=f"[FOUND WIA] Invalid Headers: {_head}! \nValidation error: {e}") + self._log_warning(context, message=f"[FOUND WIA] Invalid Headers: {_head}. Validation error: {e}") except Exception as e: - self._log_warning(context, message=f"[FOUND WIA] Invalid Headers: {_head}! \nUnexpected error: {e}") + self._log_warning(context, message=f"[FOUND WIA] Invalid Headers: {_head}. Unexpected error: {e}") try: WalletInstanceAttestationPayload(**wia) except ValidationError as e: - self._log_warning(context, message=f"[FOUND WIA] Invalid WIA: {wia}! \nValidation error: {e}") + _msg = f"[FOUND WIA] Invalid WIA: {wia}. Validation error: {e}" + self._log_warning(context, message=_msg) + # return self._handle_401(context, _msg, e) except Exception as e: - self._log_warning(context, message=f"[FOUND WIA] Invalid WIA: {wia}! \nUnexpected error: {e}") + _msg = f"[FOUND WIA] Invalid WIA: {wia}. Unexpected error: {e}" + self._log_warning(context, message=_msg) + # return self._handle_401(context, _msg, e) try: self._validate_trust(context, dpop_jws) @@ -84,7 +89,7 @@ def _request_endpoint_dpop(self, context: Context, *args) -> Union[JsonResponse, else: _msg = ( - "The Wallet Instance doesn't provide a valid Wallet Instance Attestation " + "The Wallet Instance doesn't provide a valid Wallet Attestation " "a default set of capabilities and a low security level are applied." ) - self._log_warning(context, message=_msg) \ No newline at end of file + self._log_warning(context, message=_msg) diff --git a/pyeudiw/satosa/trust.py b/pyeudiw/satosa/trust.py index 86540b31..a3232444 100644 --- a/pyeudiw/satosa/trust.py +++ b/pyeudiw/satosa/trust.py @@ -16,6 +16,7 @@ from pyeudiw.tools.base_logger import BaseLogger + class BackendTrust(BaseLogger): """ Backend Trust class. @@ -41,7 +42,10 @@ def init_trust_resources(self) -> None: try: self.get_backend_trust_chain() except Exception as e: - self._log_critical("Backend Trust", f"Cannot fetch the trust anchor configuration: {e}") + self._log_critical( + "Backend Trust", + f"Cannot fetch the trust anchor configuration: {e}" + ) self.db_engine.close() self._db_engine = None @@ -57,10 +61,9 @@ def entity_configuration_endpoint(self, context: Context) -> Response: :rtype: Response """ - data = self.entity_configuration_as_dict if context.qs_params.get('format', '') == 'json': return Response( - json.dumps(data), + json.dumps(self.entity_configuration_as_dict), status="200", content="application/json" ) @@ -101,21 +104,21 @@ def get_backend_trust_chain(self) -> list[str]: """ try: trust_evaluation_helper = TrustEvaluationHelper.build_trust_chain_for_entity_id( - storage=self.db_engine, - entity_id=self.client_id, - entity_configuration=self.entity_configuration, - httpc_params=self.config['network']['httpc_params'] + storage = self.db_engine, + entity_id = self.client_id, + entity_configuration = self.entity_configuration, + httpc_params = self.config['network']['httpc_params'] ) self.db_engine.add_or_update_trust_attestation( - entity_id=self.client_id, - attestation=trust_evaluation_helper.trust_chain, - exp=trust_evaluation_helper.exp + entity_id = self.client_id, + attestation = trust_evaluation_helper.trust_chain, + exp = trust_evaluation_helper.exp ) return trust_evaluation_helper.trust_chain except (DiscoveryFailedError, EntryNotFound, Exception) as e: message = ( - f"Error while building trust chain for client with id: {self.client_id}\n" + f"Error while building trust chain for client with id: {self.client_id}. " f"{e.__class__.__name__}: {e}" ) self._log_warning("Trust Chain", message) @@ -154,7 +157,6 @@ def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: f"{trust_eval.entity_id}" ) self._log_error(context, message) - raise NotTrustedFederationError( f"{trust_eval.entity_id} not found for Trust evaluation." ) @@ -164,7 +166,6 @@ def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: f"{trust_eval.entity_id}: {e}" ) self._log_error(context, message) - raise NotTrustedFederationError( f"{trust_eval.entity_id} is not trusted." ) @@ -208,4 +209,4 @@ def entity_configuration(self) -> dict: "typ": "entity-statement+jwt" }, plain_dict=data - ) \ No newline at end of file + ) diff --git a/pyeudiw/tests/federation/test_schema.py b/pyeudiw/tests/federation/test_schema.py index fcfe5b10..e8d50dba 100644 --- a/pyeudiw/tests/federation/test_schema.py +++ b/pyeudiw/tests/federation/test_schema.py @@ -1,6 +1,7 @@ from pyeudiw.tools.utils import iat_now, exp_from_now from pyeudiw.federation import is_es, is_ec +from pyeudiw.federation.exceptions import InvalidEntityStatement, InvalidEntityConfiguration NOW = iat_now() EXP = exp_from_now(5) @@ -127,16 +128,22 @@ def test_is_es(): - assert is_es(ta_es) + is_es(ta_es) def test_is_es_false(): - assert not is_es(ta_ec) + try: + is_es(ta_ec) + except InvalidEntityStatement as e: + pass def test_is_ec(): - assert is_ec(ta_ec) + is_ec(ta_ec) def test_is_ec_false(): - assert not is_ec(ta_es) + try: + is_ec(ta_es) + except InvalidEntityConfiguration as e: + pass diff --git a/pyeudiw/tests/federation/test_static_trust_chain_validator.py b/pyeudiw/tests/federation/test_static_trust_chain_validator.py index 1f48df13..c5638dc4 100644 --- a/pyeudiw/tests/federation/test_static_trust_chain_validator.py +++ b/pyeudiw/tests/federation/test_static_trust_chain_validator.py @@ -56,17 +56,7 @@ def test_is_valid_equals_false(): invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params ).is_valid - -def test_retrieve_ec(): - tcv.get_entity_configurations = Mock(return_value=[leaf_wallet_signed]) - - assert tcv.StaticTrustChainValidator( - invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params)._retrieve_ec("https://trust-anchor.example.org") == leaf_wallet_signed - - def test_retrieve_ec_fails(): - tcv.get_entity_configurations = Mock(return_value=[]) - try: StaticTrustChainValidator( invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params)._retrieve_ec("https://trust-anchor.example.org") @@ -74,15 +64,22 @@ def test_retrieve_ec_fails(): return +def test_retrieve_ec(): + tcv.get_entity_configurations = Mock(return_value=[leaf_wallet_signed]) + + assert tcv.StaticTrustChainValidator( + invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params)._retrieve_ec("https://trust-anchor.example.org") == leaf_wallet_signed + + def test_retrieve_es(): - tcv.get_entity_statements = Mock(return_value=ta_es) + tcv.get_entity_statements = Mock(return_value=[ta_es]) assert tcv.StaticTrustChainValidator( invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params)._retrieve_es("https://trust-anchor.example.org", "https://trust-anchor.example.org") == ta_es def test_retrieve_es_output_is_none(): - tcv.get_entity_statements = Mock(return_value=None) + tcv.get_entity_statements = Mock(return_value=[None]) assert tcv.StaticTrustChainValidator( invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params)._retrieve_es("https://trust-anchor.example.org", "https://trust-anchor.example.org") == None @@ -114,7 +111,7 @@ def test_update_st_es_case_source_endpoint(): ta_es_signed = ta_signer.sign_compact([ta_jwk]) def mock_method(*args, **kwargs): - return leaf_wallet_signed + return [leaf_wallet_signed] with mock.patch.object(tcv, "get_entity_statements", mock_method): _t = tcv.StaticTrustChainValidator( @@ -140,7 +137,7 @@ def mock_method_ec(*args, **kwargs): return [intermediate_es_wallet_signed] def mock_method_es(*args, **kwargs): - return leaf_wallet_signed + return [leaf_wallet_signed] with mock.patch.object(tcv, "get_entity_statements", mock_method_es): with mock.patch.object(tcv, "get_entity_configurations", mock_method_ec): diff --git a/pyeudiw/tests/satosa/test_backend.py b/pyeudiw/tests/satosa/test_backend.py index 2caa4887..e38d30c7 100644 --- a/pyeudiw/tests/satosa/test_backend.py +++ b/pyeudiw/tests/satosa/test_backend.py @@ -245,7 +245,7 @@ def test_vp_validation_in_redirect_endpoint(self, context): assert request_endpoint.status == "400" msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" - assert msg["error_description"] == "Error while validating VP: unexpected value." + assert msg["error_description"] # Recreate data without nonce # This will trigger a `NoNonceInVPToken` error @@ -287,7 +287,7 @@ def test_vp_validation_in_redirect_endpoint(self, context): assert request_endpoint.status == "400" msg = json.loads(request_endpoint.message) assert msg["error"] == "invalid_request" - assert msg["error_description"] == "Error while validating VP: vp has no nonce." + assert msg["error_description"] # This will trigger a `UnicodeDecodeError` which will be caught by the generic `Exception case`. response["vp_token"] = "asd.fgh.jkl" diff --git a/pyeudiw/tools/utils.py b/pyeudiw/tools/utils.py index cec68985..05f98238 100644 --- a/pyeudiw/tools/utils.py +++ b/pyeudiw/tools/utils.py @@ -5,7 +5,7 @@ import requests from secrets import token_hex -from pyeudiw.federation.http_client import http_get +from pyeudiw.federation.http_client import http_get_sync, http_get_async logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ def datetime_from_timestamp(timestamp: int | float) -> datetime.datetime: return make_timezone_aware(datetime.datetime.fromtimestamp(timestamp)) -def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[dict]: +def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[requests.Response]: """ Perform an HTTP Request returning the payload of the call. @@ -84,12 +84,9 @@ def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = T if http_async: responses = asyncio.run( - http_get(urls, httpc_params)) # pragma: no cover + http_get_async(urls, httpc_params)) # pragma: no cover else: - responses = [] - for i in urls: - res = requests.get(i, **httpc_params) # nosec - B113 - responses.append(res.content) + responses = http_get_sync(urls, httpc_params) return responses @@ -116,7 +113,7 @@ def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list[dict] = [ jwks_list = get_http_url( [jwks_uri], httpc_params=httpc_params ) - jwks_list = json.loads(jwks_list[0]) + jwks_list = jwks_list[0].json() except Exception as e: logger.error(f"Failed to download jwks from {jwks_uri}: {e}") elif metadata.get('signed_jwks_uri'): @@ -124,7 +121,7 @@ def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list[dict] = [ signed_jwks_uri = metadata["signed_jwks_uri"] jwks_list = get_http_url( [signed_jwks_uri], httpc_params=httpc_params - )[0] + )[0].json() except Exception as e: logger.error( f"Failed to download jwks from {signed_jwks_uri}: {e}") diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index 7952fae1..a61b86c5 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -261,7 +261,8 @@ def discovery(self, entity_id: str, entity_configuration: EntityStatement | None is_good = tcbuilder.is_valid if not is_good: raise DiscoveryFailedError( - f"Discovery failed for entity {entity_id}\nwith configuration {entity_configuration}") + f"Discovery failed for entity {entity_id} with configuration {entity_configuration}" + ) @staticmethod def build_trust_chain_for_entity_id(storage: DBEngine, entity_id, entity_configuration, httpc_params): diff --git a/pyeudiw/x509/verify.py b/pyeudiw/x509/verify.py index a5835318..768f73f2 100644 --- a/pyeudiw/x509/verify.py +++ b/pyeudiw/x509/verify.py @@ -9,6 +9,7 @@ logger = logging.getLogger(__name__) + def _verify_x509_certificate_chain(pems: list[str]): """ Verify the x509 certificate chain. @@ -21,25 +22,28 @@ def _verify_x509_certificate_chain(pems: list[str]): """ try: store = crypto.X509Store() - - x509_certs = [crypto.load_certificate(crypto.FILETYPE_PEM, str(pem)) for pem in pems] + x509_certs = [ + crypto.load_certificate(crypto.FILETYPE_PEM, str(pem)) + for pem in pems + ] for cert in x509_certs[:-1]: store.add_cert(cert) store_ctx = crypto.X509StoreContext(store, x509_certs[-1]) - store_ctx.verify_certificate() return True except crypto.Error as e: _message = f"cert's chain result invalid for the following reason -> {e}" logging.warning(LOG_ERROR.format(_message)) + return False except Exception as e: _message = f"cert's chain cannot be validated for error -> {e}" logging.warning(LOG_ERROR.format(e)) return False - + + def _check_chain_len(pems: list) -> bool: """ Check the x509 certificate chain lenght. @@ -50,16 +54,15 @@ def _check_chain_len(pems: list) -> bool: :returns: True if the x509 certificate chain lenght is valid else False :rtype: bool """ - chain_len = len(pems) - if chain_len < 2: message = f"invalid chain lenght -> minimum expected 2 found {chain_len}" logging.warning(LOG_ERROR.format(message)) return False return True - + + def _check_datetime(exp: datetime | None): """ Check the x509 certificate chain expiration date. @@ -80,6 +83,7 @@ def _check_datetime(exp: datetime | None): return True + def verify_x509_attestation_chain(x5c: list[bytes], exp: datetime | None = None) -> bool: """ Verify the x509 attestation certificate chain. @@ -99,7 +103,8 @@ def verify_x509_attestation_chain(x5c: list[bytes], exp: datetime | None = None) pems = [DER_cert_to_PEM_cert(cert) for cert in x5c] return _verify_x509_certificate_chain(pems) - + + def verify_x509_anchor(pem_str: str, exp: datetime | None = None) -> bool: """ Verify the x509 anchor certificate. @@ -113,15 +118,18 @@ def verify_x509_anchor(pem_str: str, exp: datetime | None = None) -> bool: :rtype: bool """ if not _check_datetime(exp): + logging.error(LOG_ERROR.format("check datetime failed")) return False pems = [str(cert) for cert in pem.parse(pem_str)] if not _check_chain_len(pems): + logging.error(LOG_ERROR.format("check chain len failed")) return False return _verify_x509_certificate_chain(pems) + def get_issuer_from_x5c(x5c: list[bytes]) -> str: """ Get the issuer from the x509 certificate chain. @@ -135,6 +143,7 @@ def get_issuer_from_x5c(x5c: list[bytes]) -> str: cert = load_der_x509_certificate(x5c[-1]) return cert.subject.rfc4514_string().split("=")[1] + def is_der_format(cert: bytes) -> str: """ Check if the certificate is in DER format. @@ -149,5 +158,6 @@ def is_der_format(cert: bytes) -> str: pem = DER_cert_to_PEM_cert(cert) crypto.load_certificate(crypto.FILETYPE_PEM, str(pem)) return True - except crypto.Error: + except crypto.Error as e: + logging.error(LOG_ERROR.format(e)) return False