diff --git a/lambdas/handlers/edge_presign_handler.py b/lambdas/handlers/edge_presign_handler.py index bd9c9de74..e3ef81887 100644 --- a/lambdas/handlers/edge_presign_handler.py +++ b/lambdas/handlers/edge_presign_handler.py @@ -1,93 +1,29 @@ -import hashlib -import json import logging -from urllib.parse import parse_qs -from enums.lambda_error import LambdaError from services.edge_presign_service import EdgePresignService from utils.decorators.handle_edge_exceptions import handle_edge_exceptions from utils.decorators.override_error_check import override_error_check from utils.decorators.set_audit_arg import set_request_context_for_logging -from utils.lambda_exceptions import CloudFrontEdgeException + +from lambdas.utils.decorators.validate_s3_request import validate_s3_request logger = logging.getLogger() logger.setLevel(logging.INFO) -REQUIRED_QUERY_PARAMS = [ - "X-Amz-Algorithm", - "X-Amz-Credential", - "X-Amz-Date", - "X-Amz-Expires", - "X-Amz-SignedHeaders", - "X-Amz-Signature", - "X-Amz-Security-Token", -] - -REQUIRED_HEADERS = ["host", "cloudfront-viewer-country", "x-forwarded-for"] - @set_request_context_for_logging @override_error_check @handle_edge_exceptions +@validate_s3_request def lambda_handler(event, context): - try: - request: dict = event["Records"][0]["cf"]["request"] - logger.info(f"CloudFront received S3 request {json.dumps(request)}") - uri: str = request["uri"] - querystring: str = request["querystring"] - headers: dict = request["headers"] - try: - origin = request.get("origin", {}) - domain_name = origin["s3"]["domainName"] - except KeyError as e: - logger.error( - f"Missing origin: {str(e)}", - {"Result": LambdaError.EdgeNoOrigin.to_str()}, - ) - raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin) - - try: - query_params = {k: v[0] for k, v in parse_qs(querystring).items()} - except ValueError: - logger.error(f"Malformed query string: {querystring}") - raise CloudFrontEdgeException(500, LambdaError.EdgeMalformedQuery) - - missing_query_params = [ - param for param in REQUIRED_QUERY_PARAMS if param not in query_params - ] - if missing_query_params: - logger.error(f"Missing required query parameters: {missing_query_params}") - raise CloudFrontEdgeException(500, LambdaError.EdgeMalformedQuery) - - missing_headers = [ - header for header in REQUIRED_HEADERS if header.lower() not in headers - ] - if missing_headers: - logger.error(f"Missing required headers: {missing_headers}") - raise CloudFrontEdgeException(500, LambdaError.EdgeMalformedHeader) - - presign_string = f"{uri}?{querystring}" - encoded_presign_string: str = presign_string.encode("utf-8") - presign_credentials_hash = hashlib.md5(encoded_presign_string).hexdigest() - - edge_presign_service = EdgePresignService() - edge_presign_service.attempt_url_update( - uri_hash=presign_credentials_hash, - domain_name=domain_name, - ) - - if "authorization" in headers: - del headers["authorization"] + request = event["Records"][0]["cf"]["request"] + logger.info("Edge received S3 request") - request["headers"] = headers - request["headers"]["host"] = [{"key": "Host", "value": domain_name}] - logger.info(f"Edge Response: {json.dumps(request)}") + edge_presign_service = EdgePresignService() + request_values = edge_presign_service.extract_request_values(request) + edge_presign_service.presign_request(request_values) - return request + request = edge_presign_service.prepare_s3_response(request, request_values) - except (KeyError, IndexError) as e: - logger.error( - f"Generic Edge Malformed Error: {str(e)}", - {"Result": LambdaError.EdgeMalformed.to_str()}, - ) - raise CloudFrontEdgeException(500, LambdaError.EdgeMalformed) + logger.info("Edge returning S3 response") + return request diff --git a/lambdas/services/edge_presign_service.py b/lambdas/services/edge_presign_service.py index 2f05cd076..6661dfa00 100644 --- a/lambdas/services/edge_presign_service.py +++ b/lambdas/services/edge_presign_service.py @@ -1,4 +1,4 @@ -import re +import hashlib from botocore.exceptions import ClientError from enums.lambda_error import LambdaError @@ -12,7 +12,6 @@ class EdgePresignService: - def __init__(self): self.dynamo_service = DynamoDBService() self.s3_service = S3Service() @@ -41,15 +40,44 @@ def attempt_url_update(self, uri_hash, domain_name) -> None: logger.error(f"{str(e)}", {"Result": LambdaError.EdgeNoClient.to_str()}) raise CloudFrontEdgeException(400, LambdaError.EdgeNoClient) + def presign_request(self, request_values): + uri = request_values["uri"] + querystring = request_values["querystring"] + domain_name = request_values["domain_name"] + + presign_string = f"{uri}?{querystring}" + encoded_presign_string = presign_string.encode("utf-8") + presign_credentials_hash = hashlib.md5(encoded_presign_string).hexdigest() + + self.attempt_url_update( + uri_hash=presign_credentials_hash, + domain_name=domain_name, + ) + @staticmethod - def extract_environment_from_domain(domain_name: str) -> str: - match = re.match(r"^[^-]+(?:-[^-]+)?(?=-lloyd)", domain_name) - if match: - return match.group(0) - return "" + def prepare_s3_response(request, request_values): + domain_name = request_values["domain_name"] + if "authorization" in request["headers"]: + del request["headers"]["authorization"] + request["headers"]["host"] = [{"key": "Host", "value": domain_name}] + + return request @staticmethod - def extend_table_name(base_table_name, environment) -> str: - if environment: - return f"{environment}_{base_table_name}" - return base_table_name + def extract_request_values(request) -> dict: + try: + uri = request["uri"] + querystring = request["querystring"] + headers = request["headers"] + origin = request.get("origin", {}) + domain_name = origin["s3"]["domainName"] + except KeyError as e: + logger.error(f"Missing request component: {str(e)}") + raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin) + + return { + "uri": uri, + "querystring": querystring, + "headers": headers, + "domain_name": domain_name, + } diff --git a/lambdas/utils/decorators/handle_edge_exceptions.py b/lambdas/utils/decorators/handle_edge_exceptions.py index 468f71512..49c42ad0a 100644 --- a/lambdas/utils/decorators/handle_edge_exceptions.py +++ b/lambdas/utils/decorators/handle_edge_exceptions.py @@ -1,35 +1,35 @@ from typing import Callable +from enums.lambda_error import LambdaError from utils.audit_logging_setup import LoggingService from utils.edge_response import EdgeResponse from utils.error_response import ErrorResponse -from utils.lambda_exceptions import LambdaException +from utils.lambda_exceptions import CloudFrontEdgeException from utils.request_context import request_context logger = LoggingService(__name__) def handle_edge_exceptions(lambda_func: Callable): - """A decorator for lambda edge handler. - Catch custom Edge Exceptions or AWS ClientError that may be unhandled or raised - - Usage: - @handle_edge_exceptions - def lambda_handler(event, context): - ... - """ - def interceptor(event, context): + interaction_id = getattr(request_context, "request_id", None) try: return lambda_func(event, context) - except LambdaException as e: + except CloudFrontEdgeException as e: logger.error(str(e)) - - interaction_id = getattr(request_context, "request_id", None) return EdgeResponse( status_code=e.status_code, body=ErrorResponse(e.err_code, e.message, interaction_id).create(), methods=event.get("httpMethod", "GET"), ).create_edge_response() + except Exception as e: + logger.error(f"Unhandled exception: {str(e)}") + err_code = LambdaError.EdgeMalformed.value("err_code") + message = LambdaError.EdgeMalformed.value("message") + return EdgeResponse( + status_code=500, + body=ErrorResponse(err_code, message, interaction_id).create(), + methods=event.get("httpMethod", "GET"), + ).create_edge_response() return interceptor diff --git a/lambdas/utils/decorators/validate_s3_request.py b/lambdas/utils/decorators/validate_s3_request.py new file mode 100644 index 000000000..20cf93c3b --- /dev/null +++ b/lambdas/utils/decorators/validate_s3_request.py @@ -0,0 +1,68 @@ +import logging +from functools import wraps +from urllib.parse import parse_qs + +from enums.lambda_error import LambdaError +from utils.lambda_exceptions import CloudFrontEdgeException + +logger = logging.getLogger(__name__) + +REQUIRED_QUERY_PARAMS = [ + "X-Amz-Algorithm", + "X-Amz-Credential", + "X-Amz-Date", + "X-Amz-Expires", + "X-Amz-SignedHeaders", + "X-Amz-Signature", + "X-Amz-Security-Token", +] + +REQUIRED_HEADERS = ["host", "cloudfront-viewer-country", "x-forwarded-for"] + + +def validate_s3_request(lambda_func): + @wraps(lambda_func) + def wrapper(event, context): + request: dict = event["Records"][0]["cf"]["request"] + + if ( + "uri" not in request + or "querystring" not in request + or "headers" not in request + ): + logger.error( + "Missing required request components: uri, querystring, or headers." + ) + raise CloudFrontEdgeException(500, LambdaError.EdgeMalformed) + + origin: dict = request.get("origin", {}) + if "s3" not in origin or "domainName" not in origin["s3"]: + logger.error("Missing origin domain name.") + raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin) + + querystring: str = request["querystring"] + if not querystring: + logger.error(f"Missing query string: {querystring}") + raise CloudFrontEdgeException(500, LambdaError.EdgeNoQuery) + + query_params = { + query: value[0] for query, value in parse_qs(querystring).items() + } + missing_query_params = [ + param for param in REQUIRED_QUERY_PARAMS if param not in query_params + ] + if missing_query_params: + logger.error(f"Missing required query parameters: {missing_query_params}") + raise CloudFrontEdgeException(500, LambdaError.EdgeMissingQuery) + + headers = request["headers"] + missing_headers = [ + header for header in REQUIRED_HEADERS if header.lower() not in headers + ] + if missing_headers: + logger.error(f"Missing required headers: {missing_headers}") + raise CloudFrontEdgeException(500, LambdaError.EdgeMissingHeaders) + + return lambda_func(event, context) + + return wrapper