Skip to content

Commit

Permalink
Feature/sd jwt issuer verify iat exp (#321)
Browse files Browse the repository at this point in the history
* fix: remove duplicated code

* fix: redundant controls

* feat: add keys selection for verify issuer token

* fix: duplicates

* fix: dead code removed

* feat: JWT and SD-JWT check iat/exp and nbf
#311

* fix: move multiple function call to single variable

* Update pyeudiw/jwt/jws_helper.py

Co-authored-by: Giuseppe De Marco <[email protected]>

* refactor: consolidate lifetime validation logic into a reusable function

* refactor: Add `LifetimeException` and centralized JWT timestamp validation

---------

Co-authored-by: Giuseppe De Marco <[email protected]>
  • Loading branch information
LadyCodesItBetter and peppelinux authored Jan 10, 2025
1 parent e593ca3 commit 68246ab
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 73 deletions.
29 changes: 28 additions & 1 deletion pyeudiw/jwt/helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json

from pydantic import ValidationError
from pyeudiw.jwk import JWK
from pyeudiw.jwk.parse import parse_key_from_x5c

Expand Down Expand Up @@ -96,7 +98,32 @@ def is_payload_expired(token_payload: dict) -> bool:
return True
return False


def is_jwt_expired(token: str) -> bool:
payload = decode_jwt_payload(token)
return is_payload_expired(payload)

class LifetimeException(ValidationError):
"""Exception raised for errors related to lifetime validation."""
pass

def validate_jwt_timestamps_claims(payload: dict) -> None:
"""
Validates the 'iat', 'exp', and 'nbf' claims in a JWT payload.
:param payload: The decoded JWT payload.
:type payload: dict
:raises ValueError: If any of the claims are invalid.
"""
current_time = iat_now()

if 'iat' in payload:
if payload['iat'] > current_time:
raise LifetimeException("Future issue time, token is invalid.")

if 'exp' in payload:
if payload['exp'] <= current_time:
raise LifetimeException("Token has expired.")

if 'nbf' in payload:
if payload['nbf'] > current_time:
raise LifetimeException("Token not yet valid.")
100 changes: 38 additions & 62 deletions pyeudiw/jwt/jws_helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import binascii
from copy import deepcopy
import datetime
import json
import logging
from typing import Any, Literal, Union
Expand All @@ -10,7 +11,7 @@
from pyeudiw.jwk.exceptions import KidError
from pyeudiw.jwk.jwks import find_jwk_by_kid, find_jwk_by_thumbprint
from pyeudiw.jwt.exceptions import JWEEncryptionError, JWSSigningError, JWSVerificationError
from pyeudiw.jwt.helper import JWHelperInterface, find_self_contained_key, is_payload_expired, serialize_payload
from pyeudiw.jwt.helper import JWHelperInterface, find_self_contained_key, is_payload_expired, serialize_payload, validate_jwt_timestamps_claims

from pyeudiw.jwk import JWK
from pyeudiw.jwt.utils import decode_jwt_header
Expand Down Expand Up @@ -98,10 +99,6 @@ def sign(
protected = {}
if unprotected is None:
unprotected = {}

# Add SD-JWT claims if the payload matches the criteria
if isinstance(plain_dict, dict) and self._is_sd_jwt_payload(plain_dict):
plain_dict = self._add_sd_jwt_claims(plain_dict)

# Select the signing key
signing_key = self._select_signing_key((protected, unprotected), signing_kid) # TODO: check that singing key is either private or symmetric
Expand All @@ -110,7 +107,7 @@ def sign(
header_kid = protected.get("kid")
signer_kid = signing_key.get("kid")
if header_kid and signer_kid and (header_kid != signer_kid):
raise JWSSigningError(f"Token header contains kid {header_kid}, which does not match the signing key kid {signer_kid}.")
raise JWSSigningError(f"token header contains a kid {header_kid} that does not match the signing key kid {signer_kid}")

payload = serialize_payload(plain_dict)

Expand All @@ -120,7 +117,7 @@ def sign(

# Add "typ" header if not present
if "typ" not in protected:
protected["typ"] = "sd-jwt" if self._is_sd_jwt_payload(plain_dict) else "JWT"
protected["typ"] = "sd-jwt" if self.is_sd_jwt(plain_dict) else "JWT"

# Include the signing key's kid in the header if required
if kid_in_header and signer_kid:
Expand All @@ -134,14 +131,21 @@ def sign(
signing_key.pop("kid", None)

signer = JWS(payload, alg=signing_alg)
keys = [key_from_jwk_dict(signing_key)]

if serialization_format == "compact":
try:
signed = signer.sign_compact([key_from_jwk_dict(signing_key)], protected=protected, **kwargs)
signed = signer.sign_compact(
keys, protected=protected, **kwargs
)
return signed
except Exception as e:
raise JWSSigningError("Signing error: error in step", e)
return signer.sign_json(keys=[key_from_jwk_dict(signing_key)], headers=[(protected, unprotected)], flatten=True)
return signer.sign_json(
keys=keys,
headers=[(protected, unprotected)],
flatten=True,
)

def _select_signing_key(self, headers: tuple[dict, dict], signing_kid: str = "") -> dict:
if len(self.jwks) == 0:
Expand Down Expand Up @@ -223,32 +227,18 @@ def verify(self, jwt: str) -> (str | Any | bytes):
"unexpected verification state: found a valid verifying key,"
f"but its kid {obtained_kid} does not match token header kid {expected_kid}")
)

verifier = JWS(alg=header["alg"])
msg = verifier.verify_compact(jwt, [key_from_jwk_dict(verifying_key)])

# Verify the JWS compact signature
verifier = JWS(alg=header["alg"])
msg = verifier.verify_compact(jwt, [key_from_jwk_dict(verifying_key)])

# Handle the payload
if isinstance(msg, (str, bytes)):
try:
# Try to interpret as JSON
decoded_payload = json.loads(msg)
except json.JSONDecodeError:
# If not JSON, assume it's a simple string (non-SD-JWT)
decoded_payload = msg
elif isinstance(msg, dict):
decoded_payload = msg
else:
raise JWSVerificationError("Unexpected type for the JWS payload.")
msg: dict = verifier.verify_compact(jwt, [key_from_jwk_dict(verifying_key)])

# Perform SD-JWT specific validations if applicable
if self._is_sd_jwt_payload(decoded_payload):
self._validate_sd_jwt(decoded_payload)
# Validate JWT claims
try:
validate_jwt_timestamps_claims(msg)
except ValueError as e:
raise JWSVerificationError(f"Invalid JWT claims: {e}")

return decoded_payload
return msg

def _select_verifying_key(self, header: dict) -> dict | None:
available_keys = [key.to_dict() for key in self.jwks]
Expand All @@ -270,40 +260,26 @@ def _select_verifying_key(self, header: dict) -> dict | None:
if len(self.jwks) == 1:
return self.jwks[0].to_dict()
return None

def _is_sd_jwt_payload(self, payload: dict) -> bool:
"""
Determines if the payload corresponds to an SD-JWT.
:param payload: The payload to inspect.
:returns: True if the payload contains SD-JWT-specific claims, False otherwise.
"""
if not isinstance(payload, dict):
return False
return payload.get("typ") == "sd-jwt"

def _add_sd_jwt_claims(self, payload: dict) -> dict:
"""
Adds SD-JWT specific claims to the payload.
:param payload: The original payload.
:returns: The payload with added SD-JWT claims.
"""
payload["iat"] = payload.get("iat", iat_now())
payload["typ"] = "sd-jwt"
return payload

def _validate_sd_jwt(self, payload: dict) -> None:
def is_sd_jwt(self, token: str) -> bool:
"""
Validates an SD-JWT payload.
Determines if the provided JWT is an SD-JWT.
:param payload: The payload to validate.
:raises JWSVerificationError: If the payload is invalid.
:param token: The JWT token to inspect.
:type token: str
:returns: True if the token is an SD-JWT, False otherwise.
:rtype: bool
"""
if payload.get("typ") != "sd-jwt":
raise JWSVerificationError("The token is not a valid SD-JWT.")
if is_payload_expired(payload):
raise JWSVerificationError("The SD-JWT token has expired.")
if not token:
return False


try:
# Decode the JWT header to inspect the 'typ' field
header = decode_jwt_header(token)

# Check if 'typ' field exists and is equal to 'sd-jwt'
return header.get("typ") == "sd-jwt"
except Exception as e:
# Log or handle errors (optional)
logger.warning(f"Unable to determine if token is SD-JWT: {e}")
return False
5 changes: 3 additions & 2 deletions pyeudiw/satosa/default/response_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,10 @@ def response_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe
pub_jwk = find_vp_token_key(token_parser, self.trust_evaluator)
except NoCriptographicMaterial as e:
return self._handle_400(context, f"VP parsing error: {e}")

token_issuer = token_parser.get_issuer_name()
whitelisted_keys = self.trust_evaluator.get_public_keys(token_issuer)
try:
token_verifier.verify_signature(pub_jwk)
token_verifier.verify_signature(whitelisted_keys)
except Exception as e:
return self._handle_400(context, f"VP parsing error: {e}")

Expand Down
6 changes: 4 additions & 2 deletions pyeudiw/sd_jwt/sd_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from hashlib import sha256
import json
from typing import Any, Callable, TypeVar
from pyeudiw.jwt.jws_helper import JWSHelper
import pyeudiw.sd_jwt.common as sd_jwtcommon
from pyeudiw.sd_jwt.common import SDJWTCommon

Expand Down Expand Up @@ -81,8 +82,9 @@ def get_sd_alg(self) -> str:
def has_key_binding(self) -> bool:
return self.holder_kb is not None

def verify_issuer_jwt_signature(self, key: ECKey | RSAKey | dict) -> None:
verify_jws_with_key(self.issuer_jwt.jwt, key)
def verify_issuer_jwt_signature(self, keys: list[ECKey | RSAKey | dict] | ECKey | RSAKey | dict) -> None:
jws_verifier = JWSHelper(keys)
jws_verifier.verify(self.issuer_jwt.jwt)

def verify_holder_kb_jwt(self, challenge: VerifierChallenge) -> None:
"""
Expand Down
14 changes: 9 additions & 5 deletions pyeudiw/sd_jwt/verifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

from pyeudiw.jwt.exceptions import JWSVerificationError
from pyeudiw.jwt.helper import validate_jwt_timestamps_claims
from pyeudiw.jwt.jws_helper import JWSHelper
from pyeudiw.sd_jwt.common import (
SDJWTCommon,
Expand All @@ -10,10 +12,8 @@
KB_DIGEST_KEY,
)

from json import dumps, loads
from typing import Dict, List, Union, Callable

from cryptojwt.jwk import JWK
from cryptojwt.jwk.jwk import key_from_jwk_dict
from cryptojwt.jws.jws import JWS

Expand Down Expand Up @@ -112,15 +112,19 @@ def _verify_sd_jwt(
keys=issuer_public_key,
sigalg=sign_alg
)
# self._sd_jwt_payload = loads(parsed_input_sd_jwt.payload.decode("utf-8"))
# TODO: Check exp/nbf/iat

try:
validate_jwt_timestamps_claims(self._sd_jwt_payload)
except ValueError as e:
raise JWSVerificationError(f"Invalid JWT claims: {e}")

else:
raise ValueError(
f"Unsupported serialization format: {self._serialization_format}"
)

self._holder_public_key_payload = self._sd_jwt_payload.get("cnf", None)

def _verify_key_binding_jwt(
self,
expected_aud: Union[str, None] = None,
Expand Down
3 changes: 2 additions & 1 deletion pyeudiw/tests/jwt/test_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from cryptojwt.jwk.ec import new_ec_key

from pyeudiw.jwt.verification import verify_jws_with_key
from pyeudiw.tools.utils import iat_now

def test_is_jwt_expired():
jwk = new_ec_key('P-256')
Expand All @@ -26,7 +27,7 @@ def test_is_jwt_not_expired():

def test_verify_jws_with_key():
jwk = new_ec_key('P-256')
payload = {"exp": 1516239022}
payload = {"exp": iat_now()+5000}

helper = JWSHelper(jwk)
jws = helper.sign(payload)
Expand Down

0 comments on commit 68246ab

Please sign in to comment.