Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sd jwt issuer verify iat exp #321

Merged
29 changes: 28 additions & 1 deletion pyeudiw/jwt/helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json

from pydantic import ValidationError
from pyeudiw.jwk import JWK
from pyeudiw.jwk.parse import parse_key_from_x5c

Expand Down Expand Up @@ -96,7 +98,32 @@ def is_payload_expired(token_payload: dict) -> bool:
return True
return False


def is_jwt_expired(token: str) -> bool:
payload = decode_jwt_payload(token)
return is_payload_expired(payload)

class LifetimeException(ValidationError):
"""Exception raised for errors related to lifetime validation."""
pass

def validate_jwt_timestamps_claims(payload: dict) -> None:
"""
Validates the 'iat', 'exp', and 'nbf' claims in a JWT payload.

:param payload: The decoded JWT payload.
:type payload: dict
:raises ValueError: If any of the claims are invalid.
"""
current_time = iat_now()

if 'iat' in payload:
if payload['iat'] > current_time:
raise LifetimeException("Future issue time, token is invalid.")

if 'exp' in payload:
if payload['exp'] <= current_time:
raise LifetimeException("Token has expired.")

if 'nbf' in payload:
if payload['nbf'] > current_time:
raise LifetimeException("Token not yet valid.")
100 changes: 38 additions & 62 deletions pyeudiw/jwt/jws_helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import binascii
from copy import deepcopy
import datetime
import json
import logging
from typing import Any, Literal, Union
Expand All @@ -10,7 +11,7 @@
from pyeudiw.jwk.exceptions import KidError
from pyeudiw.jwk.jwks import find_jwk_by_kid, find_jwk_by_thumbprint
from pyeudiw.jwt.exceptions import JWEEncryptionError, JWSSigningError, JWSVerificationError
from pyeudiw.jwt.helper import JWHelperInterface, find_self_contained_key, is_payload_expired, serialize_payload
from pyeudiw.jwt.helper import JWHelperInterface, find_self_contained_key, is_payload_expired, serialize_payload, validate_jwt_timestamps_claims

from pyeudiw.jwk import JWK
from pyeudiw.jwt.utils import decode_jwt_header
Expand Down Expand Up @@ -98,10 +99,6 @@ def sign(
protected = {}
if unprotected is None:
unprotected = {}

# Add SD-JWT claims if the payload matches the criteria
if isinstance(plain_dict, dict) and self._is_sd_jwt_payload(plain_dict):
plain_dict = self._add_sd_jwt_claims(plain_dict)

# Select the signing key
signing_key = self._select_signing_key((protected, unprotected), signing_kid) # TODO: check that singing key is either private or symmetric
Expand All @@ -110,7 +107,7 @@ def sign(
header_kid = protected.get("kid")
signer_kid = signing_key.get("kid")
if header_kid and signer_kid and (header_kid != signer_kid):
raise JWSSigningError(f"Token header contains kid {header_kid}, which does not match the signing key kid {signer_kid}.")
raise JWSSigningError(f"token header contains a kid {header_kid} that does not match the signing key kid {signer_kid}")

payload = serialize_payload(plain_dict)

Expand All @@ -120,7 +117,7 @@ def sign(

# Add "typ" header if not present
if "typ" not in protected:
protected["typ"] = "sd-jwt" if self._is_sd_jwt_payload(plain_dict) else "JWT"
protected["typ"] = "sd-jwt" if self.is_sd_jwt(plain_dict) else "JWT"
peppelinux marked this conversation as resolved.
Show resolved Hide resolved

# Include the signing key's kid in the header if required
if kid_in_header and signer_kid:
Expand All @@ -134,14 +131,21 @@ def sign(
signing_key.pop("kid", None)

signer = JWS(payload, alg=signing_alg)
keys = [key_from_jwk_dict(signing_key)]

if serialization_format == "compact":
try:
signed = signer.sign_compact([key_from_jwk_dict(signing_key)], protected=protected, **kwargs)
signed = signer.sign_compact(
keys, protected=protected, **kwargs
)
return signed
except Exception as e:
raise JWSSigningError("Signing error: error in step", e)
return signer.sign_json(keys=[key_from_jwk_dict(signing_key)], headers=[(protected, unprotected)], flatten=True)
return signer.sign_json(
keys=keys,
headers=[(protected, unprotected)],
flatten=True,
)

def _select_signing_key(self, headers: tuple[dict, dict], signing_kid: str = "") -> dict:
if len(self.jwks) == 0:
Expand Down Expand Up @@ -223,32 +227,18 @@ def verify(self, jwt: str) -> (str | Any | bytes):
"unexpected verification state: found a valid verifying key,"
f"but its kid {obtained_kid} does not match token header kid {expected_kid}")
)

verifier = JWS(alg=header["alg"])
msg = verifier.verify_compact(jwt, [key_from_jwk_dict(verifying_key)])

# Verify the JWS compact signature
verifier = JWS(alg=header["alg"])
msg = verifier.verify_compact(jwt, [key_from_jwk_dict(verifying_key)])

# Handle the payload
if isinstance(msg, (str, bytes)):
try:
# Try to interpret as JSON
decoded_payload = json.loads(msg)
except json.JSONDecodeError:
# If not JSON, assume it's a simple string (non-SD-JWT)
decoded_payload = msg
elif isinstance(msg, dict):
decoded_payload = msg
else:
raise JWSVerificationError("Unexpected type for the JWS payload.")
msg: dict = verifier.verify_compact(jwt, [key_from_jwk_dict(verifying_key)])

# Perform SD-JWT specific validations if applicable
if self._is_sd_jwt_payload(decoded_payload):
self._validate_sd_jwt(decoded_payload)
# Validate JWT claims
try:
validate_jwt_timestamps_claims(msg)
except ValueError as e:
raise JWSVerificationError(f"Invalid JWT claims: {e}")

return decoded_payload
return msg

def _select_verifying_key(self, header: dict) -> dict | None:
available_keys = [key.to_dict() for key in self.jwks]
Expand All @@ -270,40 +260,26 @@ def _select_verifying_key(self, header: dict) -> dict | None:
if len(self.jwks) == 1:
return self.jwks[0].to_dict()
return None

def _is_sd_jwt_payload(self, payload: dict) -> bool:
"""
Determines if the payload corresponds to an SD-JWT.

:param payload: The payload to inspect.
:returns: True if the payload contains SD-JWT-specific claims, False otherwise.
"""
if not isinstance(payload, dict):
return False
return payload.get("typ") == "sd-jwt"

def _add_sd_jwt_claims(self, payload: dict) -> dict:
"""
Adds SD-JWT specific claims to the payload.

:param payload: The original payload.
:returns: The payload with added SD-JWT claims.
"""
payload["iat"] = payload.get("iat", iat_now())
payload["typ"] = "sd-jwt"
return payload

def _validate_sd_jwt(self, payload: dict) -> None:
def is_sd_jwt(self, token: str) -> bool:
"""
Validates an SD-JWT payload.
Determines if the provided JWT is an SD-JWT.

:param payload: The payload to validate.

:raises JWSVerificationError: If the payload is invalid.
:param token: The JWT token to inspect.
:type token: str
:returns: True if the token is an SD-JWT, False otherwise.
:rtype: bool
"""
if payload.get("typ") != "sd-jwt":
raise JWSVerificationError("The token is not a valid SD-JWT.")
if is_payload_expired(payload):
raise JWSVerificationError("The SD-JWT token has expired.")
if not token:
return False


try:
# Decode the JWT header to inspect the 'typ' field
header = decode_jwt_header(token)

# Check if 'typ' field exists and is equal to 'sd-jwt'
return header.get("typ") == "sd-jwt"
except Exception as e:
# Log or handle errors (optional)
logger.warning(f"Unable to determine if token is SD-JWT: {e}")
return False
5 changes: 3 additions & 2 deletions pyeudiw/satosa/default/response_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,10 @@ def response_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe
pub_jwk = find_vp_token_key(token_parser, self.trust_evaluator)
except NoCriptographicMaterial as e:
return self._handle_400(context, f"VP parsing error: {e}")

token_issuer = token_parser.get_issuer_name()
whitelisted_keys = self.trust_evaluator.get_public_keys(token_issuer)
try:
token_verifier.verify_signature(pub_jwk)
token_verifier.verify_signature(whitelisted_keys)
except Exception as e:
return self._handle_400(context, f"VP parsing error: {e}")

Expand Down
6 changes: 4 additions & 2 deletions pyeudiw/sd_jwt/sd_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from hashlib import sha256
import json
from typing import Any, Callable, TypeVar
from pyeudiw.jwt.jws_helper import JWSHelper
import pyeudiw.sd_jwt.common as sd_jwtcommon
from pyeudiw.sd_jwt.common import SDJWTCommon

Expand Down Expand Up @@ -81,8 +82,9 @@ def get_sd_alg(self) -> str:
def has_key_binding(self) -> bool:
return self.holder_kb is not None

def verify_issuer_jwt_signature(self, key: ECKey | RSAKey | dict) -> None:
verify_jws_with_key(self.issuer_jwt.jwt, key)
def verify_issuer_jwt_signature(self, keys: list[ECKey | RSAKey | dict] | ECKey | RSAKey | dict) -> None:
jws_verifier = JWSHelper(keys)
jws_verifier.verify(self.issuer_jwt.jwt)

def verify_holder_kb_jwt(self, challenge: VerifierChallenge) -> None:
"""
Expand Down
14 changes: 9 additions & 5 deletions pyeudiw/sd_jwt/verifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

from pyeudiw.jwt.exceptions import JWSVerificationError
from pyeudiw.jwt.helper import validate_jwt_timestamps_claims
from pyeudiw.jwt.jws_helper import JWSHelper
from pyeudiw.sd_jwt.common import (
SDJWTCommon,
Expand All @@ -10,10 +12,8 @@
KB_DIGEST_KEY,
)

from json import dumps, loads
from typing import Dict, List, Union, Callable

from cryptojwt.jwk import JWK
from cryptojwt.jwk.jwk import key_from_jwk_dict
from cryptojwt.jws.jws import JWS

Expand Down Expand Up @@ -112,15 +112,19 @@ def _verify_sd_jwt(
keys=issuer_public_key,
sigalg=sign_alg
)
# self._sd_jwt_payload = loads(parsed_input_sd_jwt.payload.decode("utf-8"))
# TODO: Check exp/nbf/iat

try:
validate_jwt_timestamps_claims(self._sd_jwt_payload)
except ValueError as e:
raise JWSVerificationError(f"Invalid JWT claims: {e}")

else:
raise ValueError(
f"Unsupported serialization format: {self._serialization_format}"
)

self._holder_public_key_payload = self._sd_jwt_payload.get("cnf", None)

def _verify_key_binding_jwt(
self,
expected_aud: Union[str, None] = None,
Expand Down
3 changes: 2 additions & 1 deletion pyeudiw/tests/jwt/test_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from cryptojwt.jwk.ec import new_ec_key

from pyeudiw.jwt.verification import verify_jws_with_key
from pyeudiw.tools.utils import iat_now

def test_is_jwt_expired():
jwk = new_ec_key('P-256')
Expand All @@ -26,7 +27,7 @@ def test_is_jwt_not_expired():

def test_verify_jws_with_key():
jwk = new_ec_key('P-256')
payload = {"exp": 1516239022}
payload = {"exp": iat_now()+5000}

helper = JWSHelper(jwk)
jws = helper.sign(payload)
Expand Down
Loading