Skip to content

Commit

Permalink
Move exception of s3 edge to decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
RioKnightleyNHS committed Oct 24, 2024
1 parent 2590222 commit f48058d
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 99 deletions.
86 changes: 11 additions & 75 deletions lambdas/handlers/edge_presign_handler.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,29 @@
import hashlib
import json
import logging
from urllib.parse import parse_qs

from enums.lambda_error import LambdaError
from services.edge_presign_service import EdgePresignService
from utils.decorators.handle_edge_exceptions import handle_edge_exceptions
from utils.decorators.override_error_check import override_error_check
from utils.decorators.set_audit_arg import set_request_context_for_logging
from utils.lambda_exceptions import CloudFrontEdgeException

from lambdas.utils.decorators.validate_s3_request import validate_s3_request

logger = logging.getLogger()
logger.setLevel(logging.INFO)

REQUIRED_QUERY_PARAMS = [
"X-Amz-Algorithm",
"X-Amz-Credential",
"X-Amz-Date",
"X-Amz-Expires",
"X-Amz-SignedHeaders",
"X-Amz-Signature",
"X-Amz-Security-Token",
]

REQUIRED_HEADERS = ["host", "cloudfront-viewer-country", "x-forwarded-for"]


@set_request_context_for_logging
@override_error_check
@handle_edge_exceptions
@validate_s3_request
def lambda_handler(event, context):
try:
request: dict = event["Records"][0]["cf"]["request"]
logger.info(f"CloudFront received S3 request {json.dumps(request)}")
uri: str = request["uri"]
querystring: str = request["querystring"]
headers: dict = request["headers"]
try:
origin = request.get("origin", {})
domain_name = origin["s3"]["domainName"]
except KeyError as e:
logger.error(
f"Missing origin: {str(e)}",
{"Result": LambdaError.EdgeNoOrigin.to_str()},
)
raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin)

try:
query_params = {k: v[0] for k, v in parse_qs(querystring).items()}
except ValueError:
logger.error(f"Malformed query string: {querystring}")
raise CloudFrontEdgeException(500, LambdaError.EdgeMalformedQuery)

missing_query_params = [
param for param in REQUIRED_QUERY_PARAMS if param not in query_params
]
if missing_query_params:
logger.error(f"Missing required query parameters: {missing_query_params}")
raise CloudFrontEdgeException(500, LambdaError.EdgeMalformedQuery)

missing_headers = [
header for header in REQUIRED_HEADERS if header.lower() not in headers
]
if missing_headers:
logger.error(f"Missing required headers: {missing_headers}")
raise CloudFrontEdgeException(500, LambdaError.EdgeMalformedHeader)

presign_string = f"{uri}?{querystring}"
encoded_presign_string: str = presign_string.encode("utf-8")
presign_credentials_hash = hashlib.md5(encoded_presign_string).hexdigest()

edge_presign_service = EdgePresignService()
edge_presign_service.attempt_url_update(
uri_hash=presign_credentials_hash,
domain_name=domain_name,
)

if "authorization" in headers:
del headers["authorization"]
request = event["Records"][0]["cf"]["request"]
logger.info("Edge received S3 request")

request["headers"] = headers
request["headers"]["host"] = [{"key": "Host", "value": domain_name}]
logger.info(f"Edge Response: {json.dumps(request)}")
edge_presign_service = EdgePresignService()
request_values = edge_presign_service.extract_request_values(request)
edge_presign_service.presign_request(request_values)

return request
request = edge_presign_service.prepare_s3_response(request, request_values)

except (KeyError, IndexError) as e:
logger.error(
f"Generic Edge Malformed Error: {str(e)}",
{"Result": LambdaError.EdgeMalformed.to_str()},
)
raise CloudFrontEdgeException(500, LambdaError.EdgeMalformed)
logger.info("Edge returning S3 response")
return request
50 changes: 39 additions & 11 deletions lambdas/services/edge_presign_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import re
import hashlib

from botocore.exceptions import ClientError
from enums.lambda_error import LambdaError
Expand All @@ -12,7 +12,6 @@


class EdgePresignService:

def __init__(self):
self.dynamo_service = DynamoDBService()
self.s3_service = S3Service()
Expand Down Expand Up @@ -41,15 +40,44 @@ def attempt_url_update(self, uri_hash, domain_name) -> None:
logger.error(f"{str(e)}", {"Result": LambdaError.EdgeNoClient.to_str()})
raise CloudFrontEdgeException(400, LambdaError.EdgeNoClient)

def presign_request(self, request_values):
uri = request_values["uri"]
querystring = request_values["querystring"]
domain_name = request_values["domain_name"]

presign_string = f"{uri}?{querystring}"
encoded_presign_string = presign_string.encode("utf-8")
presign_credentials_hash = hashlib.md5(encoded_presign_string).hexdigest()

self.attempt_url_update(
uri_hash=presign_credentials_hash,
domain_name=domain_name,
)

@staticmethod
def extract_environment_from_domain(domain_name: str) -> str:
match = re.match(r"^[^-]+(?:-[^-]+)?(?=-lloyd)", domain_name)
if match:
return match.group(0)
return ""
def prepare_s3_response(request, request_values):
domain_name = request_values["domain_name"]
if "authorization" in request["headers"]:
del request["headers"]["authorization"]
request["headers"]["host"] = [{"key": "Host", "value": domain_name}]

return request

@staticmethod
def extend_table_name(base_table_name, environment) -> str:
if environment:
return f"{environment}_{base_table_name}"
return base_table_name
def extract_request_values(request) -> dict:
try:
uri = request["uri"]
querystring = request["querystring"]
headers = request["headers"]
origin = request.get("origin", {})
domain_name = origin["s3"]["domainName"]
except KeyError as e:
logger.error(f"Missing request component: {str(e)}")
raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin)

return {
"uri": uri,
"querystring": querystring,
"headers": headers,
"domain_name": domain_name,
}
26 changes: 13 additions & 13 deletions lambdas/utils/decorators/handle_edge_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
from typing import Callable

from enums.lambda_error import LambdaError
from utils.audit_logging_setup import LoggingService
from utils.edge_response import EdgeResponse
from utils.error_response import ErrorResponse
from utils.lambda_exceptions import LambdaException
from utils.lambda_exceptions import CloudFrontEdgeException
from utils.request_context import request_context

logger = LoggingService(__name__)


def handle_edge_exceptions(lambda_func: Callable):
"""A decorator for lambda edge handler.
Catch custom Edge Exceptions or AWS ClientError that may be unhandled or raised
Usage:
@handle_edge_exceptions
def lambda_handler(event, context):
...
"""

def interceptor(event, context):
interaction_id = getattr(request_context, "request_id", None)
try:
return lambda_func(event, context)
except LambdaException as e:
except CloudFrontEdgeException as e:
logger.error(str(e))

interaction_id = getattr(request_context, "request_id", None)
return EdgeResponse(
status_code=e.status_code,
body=ErrorResponse(e.err_code, e.message, interaction_id).create(),
methods=event.get("httpMethod", "GET"),
).create_edge_response()
except Exception as e:
logger.error(f"Unhandled exception: {str(e)}")
err_code = LambdaError.EdgeMalformed.value("err_code")
message = LambdaError.EdgeMalformed.value("message")
return EdgeResponse(
status_code=500,
body=ErrorResponse(err_code, message, interaction_id).create(),
methods=event.get("httpMethod", "GET"),
).create_edge_response()

return interceptor
68 changes: 68 additions & 0 deletions lambdas/utils/decorators/validate_s3_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import logging
from functools import wraps
from urllib.parse import parse_qs

from enums.lambda_error import LambdaError
from utils.lambda_exceptions import CloudFrontEdgeException

logger = logging.getLogger(__name__)

REQUIRED_QUERY_PARAMS = [
"X-Amz-Algorithm",
"X-Amz-Credential",
"X-Amz-Date",
"X-Amz-Expires",
"X-Amz-SignedHeaders",
"X-Amz-Signature",
"X-Amz-Security-Token",
]

REQUIRED_HEADERS = ["host", "cloudfront-viewer-country", "x-forwarded-for"]


def validate_s3_request(lambda_func):
@wraps(lambda_func)
def wrapper(event, context):
request: dict = event["Records"][0]["cf"]["request"]

if (
"uri" not in request
or "querystring" not in request
or "headers" not in request
):
logger.error(
"Missing required request components: uri, querystring, or headers."
)
raise CloudFrontEdgeException(500, LambdaError.EdgeMalformed)

origin: dict = request.get("origin", {})
if "s3" not in origin or "domainName" not in origin["s3"]:
logger.error("Missing origin domain name.")
raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin)

querystring: str = request["querystring"]
if not querystring:
logger.error(f"Missing query string: {querystring}")
raise CloudFrontEdgeException(500, LambdaError.EdgeNoQuery)

query_params = {
query: value[0] for query, value in parse_qs(querystring).items()
}
missing_query_params = [
param for param in REQUIRED_QUERY_PARAMS if param not in query_params
]
if missing_query_params:
logger.error(f"Missing required query parameters: {missing_query_params}")
raise CloudFrontEdgeException(500, LambdaError.EdgeMissingQuery)

headers = request["headers"]
missing_headers = [
header for header in REQUIRED_HEADERS if header.lower() not in headers
]
if missing_headers:
logger.error(f"Missing required headers: {missing_headers}")
raise CloudFrontEdgeException(500, LambdaError.EdgeMissingHeaders)

return lambda_func(event, context)

return wrapper

0 comments on commit f48058d

Please sign in to comment.