-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move exception of s3 edge to decorator
- Loading branch information
1 parent
2590222
commit f48058d
Showing
4 changed files
with
131 additions
and
99 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,93 +1,29 @@ | ||
import hashlib | ||
import json | ||
import logging | ||
from urllib.parse import parse_qs | ||
|
||
from enums.lambda_error import LambdaError | ||
from services.edge_presign_service import EdgePresignService | ||
from utils.decorators.handle_edge_exceptions import handle_edge_exceptions | ||
from utils.decorators.override_error_check import override_error_check | ||
from utils.decorators.set_audit_arg import set_request_context_for_logging | ||
from utils.lambda_exceptions import CloudFrontEdgeException | ||
|
||
from lambdas.utils.decorators.validate_s3_request import validate_s3_request | ||
|
||
logger = logging.getLogger() | ||
logger.setLevel(logging.INFO) | ||
|
||
REQUIRED_QUERY_PARAMS = [ | ||
"X-Amz-Algorithm", | ||
"X-Amz-Credential", | ||
"X-Amz-Date", | ||
"X-Amz-Expires", | ||
"X-Amz-SignedHeaders", | ||
"X-Amz-Signature", | ||
"X-Amz-Security-Token", | ||
] | ||
|
||
REQUIRED_HEADERS = ["host", "cloudfront-viewer-country", "x-forwarded-for"] | ||
|
||
|
||
@set_request_context_for_logging | ||
@override_error_check | ||
@handle_edge_exceptions | ||
@validate_s3_request | ||
def lambda_handler(event, context): | ||
try: | ||
request: dict = event["Records"][0]["cf"]["request"] | ||
logger.info(f"CloudFront received S3 request {json.dumps(request)}") | ||
uri: str = request["uri"] | ||
querystring: str = request["querystring"] | ||
headers: dict = request["headers"] | ||
try: | ||
origin = request.get("origin", {}) | ||
domain_name = origin["s3"]["domainName"] | ||
except KeyError as e: | ||
logger.error( | ||
f"Missing origin: {str(e)}", | ||
{"Result": LambdaError.EdgeNoOrigin.to_str()}, | ||
) | ||
raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin) | ||
|
||
try: | ||
query_params = {k: v[0] for k, v in parse_qs(querystring).items()} | ||
except ValueError: | ||
logger.error(f"Malformed query string: {querystring}") | ||
raise CloudFrontEdgeException(500, LambdaError.EdgeMalformedQuery) | ||
|
||
missing_query_params = [ | ||
param for param in REQUIRED_QUERY_PARAMS if param not in query_params | ||
] | ||
if missing_query_params: | ||
logger.error(f"Missing required query parameters: {missing_query_params}") | ||
raise CloudFrontEdgeException(500, LambdaError.EdgeMalformedQuery) | ||
|
||
missing_headers = [ | ||
header for header in REQUIRED_HEADERS if header.lower() not in headers | ||
] | ||
if missing_headers: | ||
logger.error(f"Missing required headers: {missing_headers}") | ||
raise CloudFrontEdgeException(500, LambdaError.EdgeMalformedHeader) | ||
|
||
presign_string = f"{uri}?{querystring}" | ||
encoded_presign_string: str = presign_string.encode("utf-8") | ||
presign_credentials_hash = hashlib.md5(encoded_presign_string).hexdigest() | ||
|
||
edge_presign_service = EdgePresignService() | ||
edge_presign_service.attempt_url_update( | ||
uri_hash=presign_credentials_hash, | ||
domain_name=domain_name, | ||
) | ||
|
||
if "authorization" in headers: | ||
del headers["authorization"] | ||
request = event["Records"][0]["cf"]["request"] | ||
logger.info("Edge received S3 request") | ||
|
||
request["headers"] = headers | ||
request["headers"]["host"] = [{"key": "Host", "value": domain_name}] | ||
logger.info(f"Edge Response: {json.dumps(request)}") | ||
edge_presign_service = EdgePresignService() | ||
request_values = edge_presign_service.extract_request_values(request) | ||
edge_presign_service.presign_request(request_values) | ||
|
||
return request | ||
request = edge_presign_service.prepare_s3_response(request, request_values) | ||
|
||
except (KeyError, IndexError) as e: | ||
logger.error( | ||
f"Generic Edge Malformed Error: {str(e)}", | ||
{"Result": LambdaError.EdgeMalformed.to_str()}, | ||
) | ||
raise CloudFrontEdgeException(500, LambdaError.EdgeMalformed) | ||
logger.info("Edge returning S3 response") | ||
return request |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,35 @@ | ||
from typing import Callable | ||
|
||
from enums.lambda_error import LambdaError | ||
from utils.audit_logging_setup import LoggingService | ||
from utils.edge_response import EdgeResponse | ||
from utils.error_response import ErrorResponse | ||
from utils.lambda_exceptions import LambdaException | ||
from utils.lambda_exceptions import CloudFrontEdgeException | ||
from utils.request_context import request_context | ||
|
||
logger = LoggingService(__name__) | ||
|
||
|
||
def handle_edge_exceptions(lambda_func: Callable): | ||
"""A decorator for lambda edge handler. | ||
Catch custom Edge Exceptions or AWS ClientError that may be unhandled or raised | ||
Usage: | ||
@handle_edge_exceptions | ||
def lambda_handler(event, context): | ||
... | ||
""" | ||
|
||
def interceptor(event, context): | ||
interaction_id = getattr(request_context, "request_id", None) | ||
try: | ||
return lambda_func(event, context) | ||
except LambdaException as e: | ||
except CloudFrontEdgeException as e: | ||
logger.error(str(e)) | ||
|
||
interaction_id = getattr(request_context, "request_id", None) | ||
return EdgeResponse( | ||
status_code=e.status_code, | ||
body=ErrorResponse(e.err_code, e.message, interaction_id).create(), | ||
methods=event.get("httpMethod", "GET"), | ||
).create_edge_response() | ||
except Exception as e: | ||
logger.error(f"Unhandled exception: {str(e)}") | ||
err_code = LambdaError.EdgeMalformed.value("err_code") | ||
message = LambdaError.EdgeMalformed.value("message") | ||
return EdgeResponse( | ||
status_code=500, | ||
body=ErrorResponse(err_code, message, interaction_id).create(), | ||
methods=event.get("httpMethod", "GET"), | ||
).create_edge_response() | ||
|
||
return interceptor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import logging | ||
from functools import wraps | ||
from urllib.parse import parse_qs | ||
|
||
from enums.lambda_error import LambdaError | ||
from utils.lambda_exceptions import CloudFrontEdgeException | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
REQUIRED_QUERY_PARAMS = [ | ||
"X-Amz-Algorithm", | ||
"X-Amz-Credential", | ||
"X-Amz-Date", | ||
"X-Amz-Expires", | ||
"X-Amz-SignedHeaders", | ||
"X-Amz-Signature", | ||
"X-Amz-Security-Token", | ||
] | ||
|
||
REQUIRED_HEADERS = ["host", "cloudfront-viewer-country", "x-forwarded-for"] | ||
|
||
|
||
def validate_s3_request(lambda_func): | ||
@wraps(lambda_func) | ||
def wrapper(event, context): | ||
request: dict = event["Records"][0]["cf"]["request"] | ||
|
||
if ( | ||
"uri" not in request | ||
or "querystring" not in request | ||
or "headers" not in request | ||
): | ||
logger.error( | ||
"Missing required request components: uri, querystring, or headers." | ||
) | ||
raise CloudFrontEdgeException(500, LambdaError.EdgeMalformed) | ||
|
||
origin: dict = request.get("origin", {}) | ||
if "s3" not in origin or "domainName" not in origin["s3"]: | ||
logger.error("Missing origin domain name.") | ||
raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin) | ||
|
||
querystring: str = request["querystring"] | ||
if not querystring: | ||
logger.error(f"Missing query string: {querystring}") | ||
raise CloudFrontEdgeException(500, LambdaError.EdgeNoQuery) | ||
|
||
query_params = { | ||
query: value[0] for query, value in parse_qs(querystring).items() | ||
} | ||
missing_query_params = [ | ||
param for param in REQUIRED_QUERY_PARAMS if param not in query_params | ||
] | ||
if missing_query_params: | ||
logger.error(f"Missing required query parameters: {missing_query_params}") | ||
raise CloudFrontEdgeException(500, LambdaError.EdgeMissingQuery) | ||
|
||
headers = request["headers"] | ||
missing_headers = [ | ||
header for header in REQUIRED_HEADERS if header.lower() not in headers | ||
] | ||
if missing_headers: | ||
logger.error(f"Missing required headers: {missing_headers}") | ||
raise CloudFrontEdgeException(500, LambdaError.EdgeMissingHeaders) | ||
|
||
return lambda_func(event, context) | ||
|
||
return wrapper |