diff --git a/pyeudiw/federation/statements.py b/pyeudiw/federation/statements.py index 9fd20b44..b74012fa 100644 --- a/pyeudiw/federation/statements.py +++ b/pyeudiw/federation/statements.py @@ -13,11 +13,12 @@ EntityConfigurationHeader, EntityStatementPayload ) -from pyeudiw.jwt.utils import unpad_jwt_payload, unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header from pyeudiw.jwt import JWSHelper from pyeudiw.tools.utils import get_http_url from pydantic import ValidationError +from pyeudiw.jwk import find_jwk import json import logging @@ -135,8 +136,8 @@ def __init__(self, jwt: str, httpc_params: dict): """ self.jwt = jwt - self.header = unpad_jwt_header(jwt) - self.payload = unpad_jwt_payload(jwt) + self.header = decode_jwt_header(jwt) + self.payload = decode_jwt_payload(jwt) self.id = self.payload["id"] self.sub = self.payload["sub"] @@ -165,15 +166,19 @@ def validate_by(self, ec: dict) -> bool: f"Trust Mark validation failed: " f"{e}" ) + + _kid = self.header["kid"] - - if self.header.get("kid") not in ec.kids: + if _kid not in ec.kids: raise UnknownKid( # pragma: no cover f"Trust Mark validation failed: " f"{self.header.get('kid')} not found in {ec.jwks}" ) + + _jwk = find_jwk(_kid, ec.jwks) + # verify signature - jwsh = JWSHelper(ec.jwks[ec.kids.index(self.header["kid"])]) + jwsh = JWSHelper(_jwk) payload = jwsh.verify(self.jwt) self.is_valid = True return payload @@ -189,13 +194,15 @@ def validate_by_its_issuer(self) -> bool: self.issuer_entity_configuration = get_entity_configurations( self.iss, self.httpc_params, False ) + + _kid = self.header.get('kid') try: ec = EntityStatement(self.issuer_entity_configuration[0]) ec.validate_by_itself() except UnknownKid: logger.warning( f"Trust Mark validation failed by its Issuer: " - f"{self.header.get('kid')} not found in " + f"{_kid} not found in " f"{self.issuer_entity_configuration.jwks}") return False except Exception: @@ -205,7 +212,8 @@ def validate_by_its_issuer(self) -> bool: return False # verify signature - jwsh = JWSHelper(ec.jwks[ec.kids.index(self.header["kid"])]) + _jwk = find_jwk(_kid, ec.jwks) + jwsh = JWSHelper(_jwk) payload = jwsh.verify(self.jwt) self.is_valid = True return payload @@ -241,8 +249,8 @@ def __init__( :param trust_mark_issuers_entity_confs: the list containig the trust mark's entiity confs """ self.jwt = jwt - self.header = unpad_jwt_header(jwt) - self.payload = unpad_jwt_payload(jwt) + self.header = decode_jwt_header(jwt) + self.payload = decode_jwt_payload(jwt) self.sub = self.payload["sub"] self.iss = self.payload["iss"] self.exp = self.payload["exp"] @@ -300,11 +308,15 @@ def validate_by_itself(self) -> bool: f"{e}" ) - if self.header.get("kid") not in self.kids: + _kid = self.header.get("kid") + + if _kid not in self.kids: raise UnknownKid( - f"{self.header.get('kid')} not found in {self.jwks}") # pragma: no cover + f"{_kid} not found in {self.jwks}") # pragma: no cover + # verify signature - jwsh = JWSHelper(self.jwks[self.kids.index(self.header["kid"])]) + _jwk = find_jwk(_kid, self.jwks) + jwsh = JWSHelper(_jwk) jwsh.verify(self.jwt) self.is_valid = True return True @@ -501,8 +513,8 @@ def validate_descendant_statement(self, jwt: str) -> bool: :returns: True if is valid or False otherwise :rtype: bool """ - header = unpad_jwt_header(jwt) - payload = unpad_jwt_payload(jwt) + header = decode_jwt_header(jwt) + payload = decode_jwt_payload(jwt) try: EntityConfigurationHeader(**header) @@ -520,12 +532,15 @@ def validate_descendant_statement(self, jwt: str) -> bool: f"{e}" ) - if header.get("kid") not in self.kids: + _kid = header.get("kid") + + if _kid not in self.kids: raise UnknownKid( - f"{self.header.get('kid')} not found in {self.jwks}") + f"{_kid} not found in {self.jwks}") # verify signature - jwsh = JWSHelper(self.jwks[self.kids.index(header["kid"])]) + _jwk = find_jwk(_kid, self.jwks) + jwsh = JWSHelper(_jwk) payload = jwsh.verify(jwt) self.verified_descendant_statements[payload["sub"]] = payload @@ -546,13 +561,13 @@ def validate_by_superior_statement(self, jwt: str, ec: 'EntityStatement') -> str is_valid = None payload = {} try: - payload = unpad_jwt_payload(jwt) + payload = decode_jwt_payload(jwt) ec.validate_by_itself() ec.validate_descendant_statement(jwt) _jwks = get_federation_jwks(payload) - _kids = [i.get("kid") for i in _jwks] + _jwk = find_jwk(self.header["kid"], _jwks) - jwsh = JWSHelper(_jwks[_kids.index(self.header["kid"])]) + jwsh = JWSHelper(_jwk) payload = jwsh.verify(self.jwt) is_valid = True diff --git a/pyeudiw/federation/trust_chain_validator.py b/pyeudiw/federation/trust_chain_validator.py index 9c00c53a..6426d9b8 100644 --- a/pyeudiw/federation/trust_chain_validator.py +++ b/pyeudiw/federation/trust_chain_validator.py @@ -1,7 +1,7 @@ import logging from pyeudiw.tools.utils import iat_now from pyeudiw.jwt import JWSHelper -from pyeudiw.jwt.utils import unpad_jwt_payload, unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header from pyeudiw.federation import is_es from pyeudiw.federation.policy import TrustChainPolicy from pyeudiw.federation.statements import ( @@ -15,27 +15,10 @@ KeyValidationError ) -logger = logging.getLogger(__name__) - - -def find_jwk(kid: str, jwks: list[dict]) -> dict: - """ - Find the JWK with the indicated kid in the jwks list. - - :param kid: the identifier of the jwk - :type kid: str - :param jwks: the list of jwks - :type jwks: list[dict] +from pyeudiw.jwk import find_jwk +from pyeudiw.jwk.exceptions import KidNotFoundError, InvalidKid - :returns: the jwk with the indicated kid or an empty dict if no jwk is found - :rtype: dict - """ - if not kid: - return {} - for jwk in jwks: - valid_jwk = jwk.get("kid", None) - if valid_jwk and kid == valid_jwk: - return jwk +logger = logging.getLogger(__name__) class StaticTrustChainValidator: @@ -141,8 +124,8 @@ def validate(self) -> bool: # inspect the entity statement kid header to know which # TA's public key to use for the validation last_element = rev_tc[0] - es_header = unpad_jwt_header(last_element) - es_payload = unpad_jwt_payload(last_element) + es_header = decode_jwt_header(last_element) + es_payload = decode_jwt_payload(last_element) ta_jwk = find_jwk( es_header.get("kid", None), self.trust_anchor_jwks @@ -169,13 +152,14 @@ def validate(self) -> bool: # validate the entire chain taking in cascade using fed_jwks # if valid -> update fed_jwks with $st for st in rev_tc[1:]: - st_header = unpad_jwt_header(st) - st_payload = unpad_jwt_payload(st) - jwk = find_jwk( - st_header.get("kid", None), fed_jwks - ) + st_header = decode_jwt_header(st) + st_payload = decode_jwt_payload(st) - if not jwk: + try: + jwk = find_jwk( + st_header.get("kid", None), fed_jwks + ) + except (KidNotFoundError, InvalidKid): return False jwsh = JWSHelper(jwk) @@ -237,7 +221,7 @@ def _update_st(self, st: str) -> str: :returns: the entity statement in form of JWT. :rtype: str """ - payload = unpad_jwt_payload(st) + payload = decode_jwt_payload(st) iss = payload['iss'] if not is_es(payload): # It's an entity configuration @@ -251,7 +235,7 @@ def _update_st(self, st: str) -> str: ) else: ec = self._retrieve_ec(iss) - ec_data = unpad_jwt_payload(ec) + ec_data = decode_jwt_payload(ec) fetch_api_url = None try: @@ -290,7 +274,7 @@ def update(self) -> bool: for st in self.static_trust_chain: jwt = self._update_st(st) - exp = unpad_jwt_payload(jwt)["exp"] + exp = decode_jwt_payload(jwt)["exp"] self.set_exp(exp) self.updated_trust_chain.append(jwt) @@ -316,18 +300,18 @@ def is_expired(self) -> int: def entity_id(self) -> str: """Get the chain's entity_id.""" chain = self.trust_chain - payload = unpad_jwt_payload(chain[0]) + payload = decode_jwt_payload(chain[0]) return payload["iss"] @property def final_metadata(self) -> dict: """Apply the metadata and returns the final metadata.""" anchor = self.static_trust_chain[-1] - es_anchor_payload = unpad_jwt_payload(anchor) + es_anchor_payload = decode_jwt_payload(anchor) policy = es_anchor_payload.get("metadata_policy", {}) leaf = self.static_trust_chain[0] - es_leaf_payload = unpad_jwt_payload(leaf) + es_leaf_payload = decode_jwt_payload(leaf) return TrustChainPolicy().apply_policy(es_leaf_payload["metadata"], policy) diff --git a/pyeudiw/jwk/__init__.py b/pyeudiw/jwk/__init__.py index 1834a15f..0106d2e4 100644 --- a/pyeudiw/jwk/__init__.py +++ b/pyeudiw/jwk/__init__.py @@ -6,13 +6,18 @@ from cryptojwt.jwk.jwk import key_from_jwk_dict from cryptojwt.jwk.rsa import new_rsa_key +from .exceptions import InvalidKid, KidNotFoundError + KEY_TYPES_FUNC = dict( EC=new_ec_key, RSA=new_rsa_key ) - class JWK(): + """ + The class representing a JWK istance + """ + def __init__( self, key: Union[dict, None] = None, @@ -20,7 +25,19 @@ def __init__( hash_func: str = 'SHA-256', ec_crv: str = "P-256" ) -> None: + """ + Creates an instance of JWK. + :param key: An optional key in dict form. + If no key is provided a randomic key will be generated. + :type key: Union[dict, None] + :param key_type: a string that represents the key type. Can be EC or RSA. + :type key_type: str + :param hash_func: a string that represents the hash function to use with the instance. + :type hash_func: str + :param ec_crv: a string that represents the curve to use with the instance. + :type ec_crv: str + """ kwargs = {} self.kid = "" @@ -46,10 +63,22 @@ def __init__( self.public_key = self.key.serialize() self.public_key['kid'] = self.jwk["kid"] - def as_json(self): + def as_json(self) -> str: + """ + Returns the JWK in format of json string. + + :returns: A json string that represents the key. + :rtype: str + """ return json.dumps(self.jwk) - def export_private_pem(self): + def export_private_pem(self) -> str: + """ + Returns the JWK in format of a private pem certificte. + + :returns: A private pem certificate that represents the key. + :rtype: str + """ _k = key_from_jwk_dict(self.jwk) pk = _k.private_key() pem = pk.private_bytes( @@ -59,7 +88,13 @@ def export_private_pem(self): ) return pem.decode() - def export_public_pem(self): + def export_public_pem(self) -> str: + """ + Returns the JWK in format of a public pem certificte. + + :returns: A public pem certificate that represents the key. + :rtype: str + """ _k = key_from_jwk_dict(self.jwk) pk = _k.public_key() cert = pk.public_bytes( @@ -68,9 +103,41 @@ def export_public_pem(self): ) return cert.decode() - def as_dict(self): + def as_dict(self) -> dict: + """ + Returns the JWK in format of dict. + + :returns: The key in form of dict. + :rtype: dict + """ return self.jwk def __repr__(self): # private part! return self.as_json() + +def find_jwk(kid: str, jwks: list[dict], as_dict: bool=True) -> dict | JWK: + """ + Find the JWK with the indicated kid in the jwks list. + + :param kid: the identifier of the jwk + :type kid: str + :param jwks: the list of jwks + :type jwks: list[dict] + :param as_dict: if True the return type will be a dict, JWK otherwise. + :type as_dict: bool + + :raises InvalidKid: if kid is None. + :raises KidNotFoundError: if kid is not in jwks list. + + :returns: the jwk with the indicated kid or an empty dict if no jwk is found + :rtype: dict | JWK + """ + if not kid: + raise InvalidKid("Kid cannot be empty") + for jwk in jwks: + valid_jwk = jwk.get("kid", None) + if valid_jwk and kid == valid_jwk: + return jwk if as_dict else JWK(jwk) + + raise KidNotFoundError(f"Key with Kid {kid} not found") \ No newline at end of file diff --git a/pyeudiw/jwk/exceptions.py b/pyeudiw/jwk/exceptions.py index 7f05a493..b3a84613 100644 --- a/pyeudiw/jwk/exceptions.py +++ b/pyeudiw/jwk/exceptions.py @@ -5,6 +5,8 @@ class KidError(Exception): class KidNotFoundError(Exception): pass +class InvalidKid(Exception): + pass class JwkError(Exception): pass diff --git a/pyeudiw/jwt/__init__.py b/pyeudiw/jwt/__init__.py index ed9661e3..ed7d4a13 100644 --- a/pyeudiw/jwt/__init__.py +++ b/pyeudiw/jwt/__init__.py @@ -1,6 +1,6 @@ import binascii import json -from typing import Union +from typing import Union, Any import cryptojwt from cryptojwt.exception import VerificationError @@ -12,7 +12,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwk.exceptions import KidError -from pyeudiw.jwt.utils import unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_header DEFAULT_HASH_FUNC = "SHA-256" @@ -38,10 +38,32 @@ class JWEHelper(): - def __init__(self, jwk: JWK): + """ + The helper class for work with JWEs. + """ + def __init__(self, jwk: Union[JWK, dict]): + """ + Creates an instance of JWEHelper. + + :param jwk: The JWK used to crypt and encrypt the content of JWE. + :type jwk: JWK + """ self.jwk = jwk + if isinstance(jwk, dict): + self.jwk = JWK(jwk) + self.alg = DEFAULT_SIG_KTY_MAP[self.jwk.key.kty] def encrypt(self, plain_dict: Union[dict, str, int, None], **kwargs) -> str: + """ + Generate a encrypted JWE string. + + :param plain_dict: The payload of JWE. + :type plain_dict: Union[dict, str, int, None] + :param kwargs: Other optional fields to generate the JWE. + + :returns: A string that represents the JWE. + :rtype: str + """ _key = key_from_jwk_dict(self.jwk.as_dict()) if isinstance(_key, cryptojwt.jwk.rsa.RSAKey): @@ -75,8 +97,17 @@ def encrypt(self, plain_dict: Union[dict, str, int, None], **kwargs) -> str: return _keyobj.encrypt(key=_key.public_key()) def decrypt(self, jwe: str) -> dict: + """ + Generate a dict containing the content of decrypted JWE string. + + :param jwe: A string representing the jwe. + :type jwe: str + + :returns: A dict that represents the payload of decrypted JWE. + :rtype: dict + """ try: - jwe_header = unpad_jwt_header(jwe) + jwe_header = decode_jwt_header(jwe) except (binascii.Error, Exception) as e: raise VerificationError("The JWT is not valid") @@ -97,7 +128,16 @@ def decrypt(self, jwe: str) -> dict: class JWSHelper: + """ + The helper class for work with JWEs. + """ def __init__(self, jwk: Union[JWK, dict]): + """ + Creates an instance of JWSHelper. + + :param jwk: The JWK used to sign and verify the content of JWS. + :type jwk: Union[JWK, dict] + """ self.jwk = jwk if isinstance(jwk, dict): self.jwk = JWK(jwk) @@ -109,7 +149,19 @@ def sign( protected: dict = {}, **kwargs ) -> str: - + """ + Generate a encrypted JWS string. + + :param plain_dict: The payload of JWS. + :type plain_dict: Union[dict, str, int, None] + :param protected: a dict containing all the values + to include in the protected header. + :type protected: dict + :param kwargs: Other optional fields to generate the JWE. + + :returns: A string that represents the JWS. + :rtype: str + """ _key = key_from_jwk_dict(self.jwk.as_dict()) _payload: str | int | bytes = "" @@ -126,10 +178,20 @@ def sign( _signer = JWSec(_payload, alg=self.alg, **kwargs) return _signer.sign_compact([_key], protected=protected, **kwargs) - def verify(self, jws: str, **kwargs): + def verify(self, jws: str, **kwargs) -> (str | Any | bytes): + """ + Verify a JWS string. + + :param jws: A string representing the jwe. + :type jws: str + :param kwargs: Other optional fields to generate the JWE. + + :returns: A string that represents the payload of JWS. + :rtype: str + """ _key = key_from_jwk_dict(self.jwk.as_dict()) _jwk_dict = self.jwk.as_dict() - _head = unpad_jwt_header(jws) + _head = decode_jwt_header(jws) if _head.get("kid"): if _head["kid"] != _jwk_dict["kid"]: # pragma: no cover diff --git a/pyeudiw/jwt/exceptions.py b/pyeudiw/jwt/exceptions.py index 2e059616..cec4a78c 100644 --- a/pyeudiw/jwt/exceptions.py +++ b/pyeudiw/jwt/exceptions.py @@ -1,2 +1,5 @@ class JWEDecryptionError(Exception): pass + +class JWTInvalidElementPosition(Exception): + pass \ No newline at end of file diff --git a/pyeudiw/jwt/utils.py b/pyeudiw/jwt/utils.py index ca89ca50..b44e8a6e 100644 --- a/pyeudiw/jwt/utils.py +++ b/pyeudiw/jwt/utils.py @@ -2,41 +2,100 @@ import json import re +from typing import Dict +from pyeudiw.jwt.exceptions import JWTInvalidElementPosition +from pyeudiw.jwk import find_jwk + # JWT_REGEXP = r"^(([-A-Za-z0-9\=_])*\.([-A-Za-z0-9\=_])*\.([-A-Za-z0-9\=_])*)$" JWT_REGEXP = r'^[\w\-]+\.[\w\-]+\.[\w\-]+' -def unpad_jwt_element(jwt: str, position: int) -> dict: +def decode_jwt_element(jwt: str, position: int) -> dict: + """ + Decodes the element in a determinated position. + + :param jwt: a string that represents the jwt. + :type jwt: str + :param position: the position of segment to unpad. + :type position: int + + :raises JWTInvalidElementPosition: If the JWT element position is greather then one or less of 0 + + :returns: a dict with the content of the decoded section. + :rtype: dict + """ + if position > 1 or position < 0: + raise JWTInvalidElementPosition(f"JWT has no element in position {position}") + if isinstance(jwt, bytes): jwt = jwt.decode() + b = jwt.split(".")[position] padded = f"{b}{'=' * divmod(len(b), 4)[1]}" data = json.loads(base64.urlsafe_b64decode(padded)) return data -def unpad_jwt_header(jwt: str) -> dict: - return unpad_jwt_element(jwt, position=0) +def decode_jwt_header(jwt: str) -> dict: + """ + Decodes the jwt header. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :returns: a dict with the content of the decoded header. + :rtype: dict + """ + return decode_jwt_element(jwt, position=0) + + +def decode_jwt_payload(jwt: str) -> dict: + """ + Decodes the jwt payload. + :param jwt: a string that represents the jwt. + :type jwt: str -def unpad_jwt_payload(jwt: str) -> dict: - return unpad_jwt_element(jwt, position=1) + :returns: a dict with the content of the decoded payload. + :rtype: dict + """ + return decode_jwt_element(jwt, position=1) -def get_jwk_from_jwt(jwt: str, provider_jwks: dict) -> dict: +def get_jwk_from_jwt(jwt: str, provider_jwks: Dict[str, dict]) -> dict: """ - docs here + Find the JWK inside the provider JWKs with the kid + specified in jwt header. + + :param jwt: a string that represents the jwt. + :type jwt: str + :param provider_jwks: a dictionary that contains one or more JWKs with the KID as the key. + :type provider_jwks: Dict[str, dict] + + :raises InvalidKid: if kid is None. + :raises KidNotFoundError: if kid is not in jwks list. + + :returns: the jwk as dict. + :rtype: dict """ - head = unpad_jwt_header(jwt) + head = decode_jwt_header(jwt) kid = head["kid"] if isinstance(provider_jwks, dict) and provider_jwks.get('keys'): provider_jwks = provider_jwks['keys'] - for jwk in provider_jwks: - if jwk["kid"] == kid: - return jwk - return {} + + return find_jwk(kid, provider_jwks) def is_jwt_format(jwt: str) -> bool: + """ + Check if a string is in JWT format. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :returns: True if the string is a JWT, False otherwise. + :rtype: bool + """ + res = re.match(JWT_REGEXP, jwt) return bool(res) diff --git a/pyeudiw/oauth2/dpop/__init__.py b/pyeudiw/oauth2/dpop/__init__.py index 2672df96..986efd44 100644 --- a/pyeudiw/oauth2/dpop/__init__.py +++ b/pyeudiw/oauth2/dpop/__init__.py @@ -11,7 +11,7 @@ ) from pyeudiw.jwk.exceptions import KidError from pyeudiw.jwt import JWSHelper -from pyeudiw.jwt.utils import unpad_jwt_header, unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop.schema import ( DPoPTokenHeaderSchema, DPoPTokenPayloadSchema @@ -22,7 +22,20 @@ class DPoPIssuer: + """ + Helper class for generate DPoP proofs. + """ def __init__(self, htu: str, token: str, private_jwk: dict): + """ + Generates an instance of DPoPIssuer. + + :param htu: a string representing the htu value. + :type htu: str + :param token: a string representing the token value. + :type token: str + :param private_jwk: a dict representing the private JWK of DPoP. + :type private_jwk: dict + """ self.token = token self.private_jwk = private_jwk self.signer = JWSHelper(private_jwk) @@ -30,6 +43,7 @@ def __init__(self, htu: str, token: str, private_jwk: dict): @property def proof(self): + """Returns the proof.""" data = { "jti": str(uuid.uuid4()), "htm": "GET", @@ -48,6 +62,10 @@ def proof(self): class DPoPVerifier: + """ + Helper class for validate DPoP proofs. + """ + dpop_header_prefix = 'DPoP ' def __init__( @@ -56,6 +74,19 @@ def __init__( http_header_authz: str, http_header_dpop: str, ): + """ + Generate an instance of DPoPVerifier. + + :param public_jwk: a dict representing the public JWK of DPoP. + :type public_jwk: dict + :param http_header_authz: a string representing the authz value. + :type http_header_authz: str + :param http_header_dpop: a string representing the DPoP value. + :type http_header_dpop: str + + :raises ValueError: if DPoP proof is not a valid JWT + + """ self.public_jwk = public_jwk self.dpop_token = ( http_header_authz.replace(self.dpop_header_prefix, '') @@ -72,7 +103,7 @@ def __init__( ) # If the jwt is invalid, this will raise an exception try: - unpad_jwt_header(http_header_dpop) + decode_jwt_header(http_header_dpop) except UnicodeDecodeError as e: logger.error( "DPoP proof validation error, " @@ -89,9 +120,20 @@ def __init__( @property def is_valid(self) -> bool: + """Returns True if DPoP is valid.""" return self.validate() def validate(self) -> bool: + """ + Validates the content of DPoP. + + :raises InvalidDPoPKid: if the kid of DPoP is invalid. + :raises InvalidDPoPAth: if the header's JWK is different from public_jwk's one. + + :returns: True if the validation is correctly executed, False otherwise + :rtype: bool + """ + jws_verifier = JWSHelper(self.public_jwk) try: dpop_valid = jws_verifier.verify(self.proof) @@ -108,7 +150,7 @@ def validate(self) -> bool: f"{e.__class__.__name__}: {e}" ) - header = unpad_jwt_header(self.proof) + header = decode_jwt_header(self.proof) DPoPTokenHeaderSchema(**header) if header['jwk'] != self.public_jwk: @@ -118,7 +160,7 @@ def validate(self) -> bool: f"{header['jwk']} != {self.public_jwk}" )) - payload = unpad_jwt_payload(self.proof) + payload = decode_jwt_payload(self.proof) DPoPTokenPayloadSchema(**payload) _ath = hashlib.sha256(self.dpop_token.encode()) diff --git a/pyeudiw/openid4vp/direct_post_response.py b/pyeudiw/openid4vp/direct_post_response.py index 5099256e..8a76c9c1 100644 --- a/pyeudiw/openid4vp/direct_post_response.py +++ b/pyeudiw/openid4vp/direct_post_response.py @@ -3,7 +3,7 @@ from pyeudiw.jwt import JWEHelper from pyeudiw.jwt.exceptions import JWEDecryptionError from pyeudiw.jwk.exceptions import KidNotFoundError -from pyeudiw.jwt.utils import unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_header from pyeudiw.openid4vp.exceptions import ( VPNotFound, VPInvalidNonce, @@ -16,7 +16,7 @@ class DirectPostResponse: def __init__(self, jwt: str, jwks_by_kids: dict, nonce: str = ""): - self.headers = unpad_jwt_header(jwt) + self.headers = decode_jwt_header(jwt) self.jwks_by_kids = jwks_by_kids self.jwt = jwt self.nonce = nonce diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index 69ab61ad..fcfb385b 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -1,5 +1,5 @@ -from pyeudiw.jwt.utils import unpad_jwt_payload, unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header from pyeudiw.openid4vp.vp_sd_jwt import VpSdJwt @@ -7,9 +7,9 @@ class Vp(VpSdJwt): def __init__(self, jwt: str): # TODO: what if the credential is not a JWT? - self.headers = unpad_jwt_header(jwt) + self.headers = decode_jwt_header(jwt) self.jwt = jwt - self.payload = unpad_jwt_payload(jwt) + self.payload = decode_jwt_payload(jwt) self.credential_headers: dict = {} self.credential_payload: dict = {} @@ -35,8 +35,8 @@ def credential_issuer(self): def parse_digital_credential(self): _typ = self._detect_vp_type() if _typ == 'jwt': - self.credential_headers = unpad_jwt_header(self.payload['vp']) - self.credential_payload = unpad_jwt_payload(self.payload['vp']) + self.credential_headers = decode_jwt_header(self.payload['vp']) + self.credential_payload = decode_jwt_payload(self.payload['vp']) else: raise NotImplementedError( f"VP Digital credentials type not implemented yet: {_typ}" diff --git a/pyeudiw/satosa/dpop.py b/pyeudiw/satosa/dpop.py index c0775cb9..366bb45b 100644 --- a/pyeudiw/satosa/dpop.py +++ b/pyeudiw/satosa/dpop.py @@ -3,7 +3,7 @@ from typing import Union -from pyeudiw.jwt.utils import unpad_jwt_header, unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPVerifier from pyeudiw.openid4vp.schemas.wallet_instance_attestation import WalletInstanceAttestationPayload, \ WalletInstanceAttestationHeader @@ -25,8 +25,8 @@ def _request_endpoint_dpop(self, context, *args) -> Union[JsonResponse, None]: # take WIA dpop_jws = context.http_headers['HTTP_AUTHORIZATION'].split()[-1] - _head = unpad_jwt_header(dpop_jws) - wia = unpad_jwt_payload(dpop_jws) + _head = decode_jwt_header(dpop_jws) + wia = decode_jwt_payload(dpop_jws) self._log( context, diff --git a/pyeudiw/satosa/trust.py b/pyeudiw/satosa/trust.py index 73a066f8..5842c78a 100644 --- a/pyeudiw/satosa/trust.py +++ b/pyeudiw/satosa/trust.py @@ -8,8 +8,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwt import JWSHelper -from pyeudiw.jwt.utils import unpad_jwt_header -from pyeudiw.federation.trust_chain_builder import TrustChainBuilder +from pyeudiw.jwt.utils import decode_jwt_header from pyeudiw.satosa.exceptions import ( NotTrustedFederationError, DiscoveryFailedError ) @@ -166,7 +165,7 @@ def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: ) ) - headers = unpad_jwt_header(jws) + headers = decode_jwt_header(jws) trust_eval = TrustEvaluationHelper( self.db_engine, httpc_params=self.config['network']['httpc_params'], diff --git a/pyeudiw/sd_jwt/__init__.py b/pyeudiw/sd_jwt/__init__.py index 14cd1d89..0fbe0e53 100644 --- a/pyeudiw/sd_jwt/__init__.py +++ b/pyeudiw/sd_jwt/__init__.py @@ -13,7 +13,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwt import DEFAULT_SIG_KTY_MAP -from pyeudiw.jwt.utils import unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_payload from pyeudiw.tools.utils import exp_from_now, iat_now from jwcrypto.jws import JWS @@ -167,7 +167,7 @@ def verify_sd_jwt( settings.update( { - "issuer": unpad_jwt_payload(sd_jwt_presentation)["iss"] + "issuer": decode_jwt_payload(sd_jwt_presentation)["iss"] } ) adapted_keys = { diff --git a/pyeudiw/tests/oauth2/test_dpop.py b/pyeudiw/tests/oauth2/test_dpop.py index f65a8a10..b5a82e2b 100644 --- a/pyeudiw/tests/oauth2/test_dpop.py +++ b/pyeudiw/tests/oauth2/test_dpop.py @@ -4,7 +4,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwt import JWSHelper -from pyeudiw.jwt.utils import unpad_jwt_header, unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPIssuer, DPoPVerifier from pyeudiw.oauth2.dpop.exceptions import InvalidDPoPKid from pyeudiw.tools.utils import iat_now @@ -67,7 +67,7 @@ def wia_jws(jwshelper): def test_create_validate_dpop_http_headers(wia_jws, private_jwk=PRIVATE_JWK): # create - header = unpad_jwt_header(wia_jws) + header = decode_jwt_header(wia_jws) assert header assert isinstance(header["trust_chain"], list) assert isinstance(header["x5c"], list) @@ -82,13 +82,13 @@ def test_create_validate_dpop_http_headers(wia_jws, private_jwk=PRIVATE_JWK): proof = new_dpop.proof assert proof - header = unpad_jwt_header(proof) + header = decode_jwt_header(proof) assert header["typ"] == "dpop+jwt" assert header["alg"] assert "mac" not in str(header["alg"]).lower() assert "d" not in header["jwk"] - payload = unpad_jwt_payload(proof) + payload = decode_jwt_payload(proof) assert payload["ath"] == base64.urlsafe_b64encode( hashlib.sha256(wia_jws.encode() ).digest()).rstrip(b'=').decode() diff --git a/pyeudiw/tests/satosa/test_backend.py b/pyeudiw/tests/satosa/test_backend.py index 15b9540e..cd507927 100644 --- a/pyeudiw/tests/satosa/test_backend.py +++ b/pyeudiw/tests/satosa/test_backend.py @@ -13,9 +13,9 @@ from sd_jwt.holder import SDJWTHolder from pyeudiw.jwk import JWK -from pyeudiw.jwt import JWEHelper, JWSHelper, unpad_jwt_header, DEFAULT_SIG_KTY_MAP +from pyeudiw.jwt import JWEHelper, JWSHelper, decode_jwt_header, DEFAULT_SIG_KTY_MAP from cryptojwt.jws.jws import JWS -from pyeudiw.jwt.utils import unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPIssuer from pyeudiw.satosa.backend import OpenID4VPBackend from pyeudiw.sd_jwt import ( @@ -522,8 +522,8 @@ def test_request_endpoint(self, context): msg = json.loads(request_endpoint.message) assert msg["response"] - header = unpad_jwt_header(msg["response"]) - payload = unpad_jwt_payload(msg["response"]) + header = decode_jwt_header(msg["response"]) + payload = decode_jwt_payload(msg["response"]) assert header["alg"] assert header["kid"] assert payload["scope"] == " ".join(CONFIG["authorization"]["scopes"]) diff --git a/pyeudiw/tests/test_jwt.py b/pyeudiw/tests/test_jwt.py index 0e8771a9..d0982098 100644 --- a/pyeudiw/tests/test_jwt.py +++ b/pyeudiw/tests/test_jwt.py @@ -3,7 +3,7 @@ from pyeudiw.jwk import JWK from pyeudiw.jwt import (DEFAULT_ENC_ALG_MAP, DEFAULT_ENC_ENC_MAP, JWEHelper, JWSHelper) -from pyeudiw.jwt.utils import unpad_jwt_header +from pyeudiw.jwt.utils import decode_jwt_header JWKs_EC = [ (JWK(key_type="EC"), {"key": "value"}), @@ -24,11 +24,11 @@ @pytest.mark.parametrize("jwk, payload", JWKs_RSA) -def test_unpad_jwt_header(jwk, payload): +def test_decode_jwt_header(jwk, payload): jwe_helper = JWEHelper(jwk) jwe = jwe_helper.encrypt(payload) assert jwe - header = unpad_jwt_header(jwe) + header = decode_jwt_header(jwe) assert header assert header["alg"] == DEFAULT_ENC_ALG_MAP[jwk.jwk["kty"]] assert header["enc"] == DEFAULT_ENC_ENC_MAP[jwk.jwk["kty"]] diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index ef8cd089..82c1bad4 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -6,7 +6,7 @@ from pyeudiw.federation.exceptions import ProtocolMetadataNotFound from pyeudiw.satosa.exceptions import DiscoveryFailedError from pyeudiw.storage.db_engine import DBEngine -from pyeudiw.jwt.utils import unpad_jwt_payload, is_jwt_format +from pyeudiw.jwt.utils import decode_jwt_payload, is_jwt_format from pyeudiw.x509.verify import verify_x509_anchor, get_issuer_from_x5c, is_der_format from pyeudiw.storage.exceptions import EntryNotFound @@ -74,7 +74,7 @@ def _update_chain(self, entity_id: str | None = None, exp: datetime | None = Non self.trust_chain = trust_chain def _handle_federation_chain(self): - _first_statement = unpad_jwt_payload(self.trust_chain[-1]) + _first_statement = decode_jwt_payload(self.trust_chain[-1]) trust_anchor_eid = self.trust_anchor or _first_statement.get( 'iss', None) @@ -92,7 +92,7 @@ def _handle_federation_chain(self): "a recognizable Trust Anchor." ) - decoded_ec = unpad_jwt_payload( + decoded_ec = decode_jwt_payload( trust_anchor['federation']['entity_configuration'] ) jwks = decoded_ec.get('jwks', {}).get('keys', []) @@ -209,7 +209,7 @@ def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: for policy in policies: policy_acc = combine(policy, policy_acc) - self.final_metadata = unpad_jwt_payload(self.trust_chain[0]) + self.final_metadata = decode_jwt_payload(self.trust_chain[0]) try: # TODO: there are some cases where the jwks are taken from a uri ... diff --git a/pyeudiw/trust/trust_chain.py b/pyeudiw/trust/trust_chain.py index fc074a64..bc9c0820 100644 --- a/pyeudiw/trust/trust_chain.py +++ b/pyeudiw/trust/trust_chain.py @@ -4,7 +4,7 @@ from cryptojwt import KeyJar from cryptojwt.jwt import utc_time_sans_frac -from pyeudiw.jwt.utils import unpad_jwt_payload +from pyeudiw.jwt.utils import decode_jwt_payload __author__ = "Roland Hedberg" __license__ = "Apache 2.0"