From 172a6bb58e8cb25cdf7f49f1f4ab3a8ead5ef0a5 Mon Sep 17 00:00:00 2001 From: NogaNHS Date: Wed, 30 Oct 2024 16:39:54 +0000 Subject: [PATCH] [prmp-1120] new nrl service --- lambdas/services/base/nhs_oauth_service.py | 110 ++++++ lambdas/services/nrl_api_service.py | 30 ++ lambdas/services/pds_api_service.py | 97 +---- lambdas/tests/unit/helpers/mock_services.py | 3 + .../services/base/test_nhs_oauth_service.py | 275 ++++++++++++++ .../unit/services/test_pds_api_service.py | 355 ++---------------- lambdas/utils/exceptions.py | 4 + 7 files changed, 456 insertions(+), 418 deletions(-) create mode 100644 lambdas/services/base/nhs_oauth_service.py create mode 100644 lambdas/services/nrl_api_service.py create mode 100644 lambdas/tests/unit/services/base/test_nhs_oauth_service.py diff --git a/lambdas/services/base/nhs_oauth_service.py b/lambdas/services/base/nhs_oauth_service.py new file mode 100644 index 000000000..139a3eaff --- /dev/null +++ b/lambdas/services/base/nhs_oauth_service.py @@ -0,0 +1,110 @@ +import json +import time +import uuid + +import jwt +import requests +from enums.pds_ssm_parameters import SSMParameter +from requests.exceptions import HTTPError +from utils.audit_logging_setup import LoggingService +from utils.exceptions import OAuthErrorException + +logger = LoggingService(__name__) + + +class NhsOauthService: + def __init__(self, ssm_service): + self.ssm_service = ssm_service + + def create_access_token(self): + access_token_response = self.get_current_access_token() + access_token_response = json.loads(access_token_response) + access_token = access_token_response["access_token"] + access_token_expiration = ( + int(access_token_response["expires_in"]) + + int(access_token_response["issued_at"]) / 1000 + ) + time_safety_margin_seconds = 10 + remaining_time_before_expiration = access_token_expiration - time.time() + if remaining_time_before_expiration < time_safety_margin_seconds: + access_token = self.get_new_access_token() + + return access_token + + def get_new_access_token(self): + logger.info("Getting new OAuth access token") + try: + access_token_ssm_parameter = self.get_parameters_for_new_access_token() + jwt_token = self.create_jwt_token_for_new_access_token_request( + access_token_ssm_parameter + ) + nhs_oauth_endpoint = access_token_ssm_parameter[ + SSMParameter.NHS_OAUTH_ENDPOINT.value + ] + nhs_oauth_response = self.request_new_access_token( + jwt_token, nhs_oauth_endpoint + ) + nhs_oauth_response.raise_for_status() + token_access_response = nhs_oauth_response.json() + self.update_access_token_ssm(json.dumps(token_access_response)) + except HTTPError as e: + logger.error( + e.response, {"Result": "Issue while creating new access token"} + ) + raise OAuthErrorException("Error creating oauth access token") + return token_access_response["access_token"] + + def get_parameters_for_new_access_token(self): + parameters = [ + SSMParameter.NHS_OAUTH_ENDPOINT.value, + SSMParameter.PDS_KID.value, + SSMParameter.NHS_OAUTH_KEY.value, + SSMParameter.PDS_API_KEY.value, + ] + return self.ssm_service.get_ssm_parameters(parameters, with_decryption=True) + + def update_access_token_ssm(self, parameter_value: str): + parameter_key = SSMParameter.PDS_API_ACCESS_TOKEN.value + self.ssm_service.update_ssm_parameter( + parameter_key=parameter_key, + parameter_value=parameter_value, + parameter_type="SecureString", + ) + + def get_current_access_token(self): + parameters = [ + SSMParameter.PDS_API_ACCESS_TOKEN.value, + ] + ssm_response = self.ssm_service.get_ssm_parameter( + parameters_keys=parameters, with_decryption=True + ) + return ssm_response + + def create_jwt_token_for_new_access_token_request( + self, access_token_ssm_parameters + ): + nhs_oauth_endpoint = access_token_ssm_parameters[ + SSMParameter.NHS_OAUTH_ENDPOINT.value + ] + kid = access_token_ssm_parameters[SSMParameter.PDS_KID.value] + nhs_key = access_token_ssm_parameters[SSMParameter.NHS_OAUTH_KEY.value] + pds_key = access_token_ssm_parameters[SSMParameter.PDS_API_KEY.value] + payload = { + "iss": nhs_key, + "sub": nhs_key, + "aud": nhs_oauth_endpoint, + "jti": str(uuid.uuid4()), + "exp": int(time.time()) + 300, + } + return jwt.encode(payload, pds_key, algorithm="RS512", headers={"kid": kid}) + + def request_new_access_token(self, jwt_token, nhs_oauth_endpoint): + access_token_headers = {"content-type": "application/x-www-form-urlencoded"} + access_token_data = { + "grant_type": "client_credentials", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_assertion": jwt_token, + } + return requests.post( + url=nhs_oauth_endpoint, headers=access_token_headers, data=access_token_data + ) diff --git a/lambdas/services/nrl_api_service.py b/lambdas/services/nrl_api_service.py new file mode 100644 index 000000000..4d4d5d51c --- /dev/null +++ b/lambdas/services/nrl_api_service.py @@ -0,0 +1,30 @@ +import uuid + +from services.base.nhs_oauth_service import NhsOauthService +from utils.audit_logging_setup import LoggingService + +logger = LoggingService(__name__) + + +class NrlApiService(NhsOauthService): + def __init__(self, ssm_service): + super().__init__(ssm_service) + self.headers = { + "Authorization": f"Bearer {self.create_access_token()}", + "Accept": "application/json", + } + + def get_api_endpoint(self): + pass + + def create_new_pointer(self, body, headers): + self.set_x_request_id() + + def update_pointer(self): + self.set_x_request_id() + + def delete_pointer(self): + self.set_x_request_id() + + def set_x_request_id(self): + self.headers["X-Request-ID"] = uuid.uuid4() diff --git a/lambdas/services/pds_api_service.py b/lambdas/services/pds_api_service.py index 4eb728374..619867d8a 100644 --- a/lambdas/services/pds_api_service.py +++ b/lambdas/services/pds_api_service.py @@ -1,14 +1,12 @@ -import json -import time import uuid from json import JSONDecodeError -import jwt import requests from botocore.exceptions import ClientError from enums.pds_ssm_parameters import SSMParameter from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, HTTPError, Timeout +from services.base.nhs_oauth_service import NhsOauthService from services.patient_search_service import PatientSearch from urllib3 import Retry from utils.audit_logging_setup import LoggingService @@ -17,9 +15,9 @@ logger = LoggingService(__name__) -class PdsApiService(PatientSearch): +class PdsApiService(PatientSearch, NhsOauthService): def __init__(self, ssm_service): - self.ssm_service = ssm_service + super().__init__(ssm_service) retry_strategy = Retry( total=3, @@ -34,17 +32,8 @@ def __init__(self, ssm_service): def pds_request(self, nhs_number: str, retry_on_expired: bool): try: - endpoint, access_token_response = self.get_parameters_for_pds_api_request() - access_token_response = json.loads(access_token_response) - access_token = access_token_response["access_token"] - access_token_expiration = ( - int(access_token_response["expires_in"]) - + int(access_token_response["issued_at"]) / 1000 - ) - time_safety_margin_seconds = 10 - remaining_time_before_expiration = access_token_expiration - time.time() - if remaining_time_before_expiration < time_safety_margin_seconds: - access_token = self.get_new_access_token() + endpoint = self.get_endpoint_for_pds_api_request() + access_token = self.create_access_token() x_request_id = str(uuid.uuid4()) @@ -71,81 +60,11 @@ def pds_request(self, nhs_number: str, retry_on_expired: bool): logger.error(str(e), {"Result": "Error when calling PDS"}) raise PdsTooManyRequestsException("Failed to perform patient search") - def get_new_access_token(self): - logger.info("Getting new PDS access token") - try: - access_token_ssm_parameter = self.get_parameters_for_new_access_token() - jwt_token = self.create_jwt_token_for_new_access_token_request( - access_token_ssm_parameter - ) - nhs_oauth_endpoint = access_token_ssm_parameter[ - SSMParameter.NHS_OAUTH_ENDPOINT.value - ] - nhs_oauth_response = self.request_new_access_token( - jwt_token, nhs_oauth_endpoint - ) - nhs_oauth_response.raise_for_status() - token_access_response = nhs_oauth_response.json() - self.update_access_token_ssm(json.dumps(token_access_response)) - except HTTPError as e: - logger.error( - e.response, {"Result": "Issue while creating new access token"} - ) - raise PdsErrorException("Error accessing PDS API") - return token_access_response["access_token"] - - def get_parameters_for_new_access_token(self): - parameters = [ - SSMParameter.NHS_OAUTH_ENDPOINT.value, - SSMParameter.PDS_KID.value, - SSMParameter.NHS_OAUTH_KEY.value, - SSMParameter.PDS_API_KEY.value, - ] - return self.ssm_service.get_ssm_parameters(parameters, with_decryption=True) - - def update_access_token_ssm(self, parameter_value: str): - parameter_key = SSMParameter.PDS_API_ACCESS_TOKEN.value - self.ssm_service.update_ssm_parameter( - parameter_key=parameter_key, - parameter_value=parameter_value, - parameter_type="SecureString", - ) - - def get_parameters_for_pds_api_request(self): + def get_endpoint_for_pds_api_request(self): parameters = [ SSMParameter.PDS_API_ENDPOINT.value, - SSMParameter.PDS_API_ACCESS_TOKEN.value, ] - ssm_response = self.ssm_service.get_ssm_parameters( + ssm_response = self.ssm_service.get_ssm_parameter( parameters_keys=parameters, with_decryption=True ) - return ssm_response[parameters[0]], ssm_response[parameters[1]] - - def create_jwt_token_for_new_access_token_request( - self, access_token_ssm_parameters - ): - nhs_oauth_endpoint = access_token_ssm_parameters[ - SSMParameter.NHS_OAUTH_ENDPOINT.value - ] - kid = access_token_ssm_parameters[SSMParameter.PDS_KID.value] - nhs_key = access_token_ssm_parameters[SSMParameter.NHS_OAUTH_KEY.value] - pds_key = access_token_ssm_parameters[SSMParameter.PDS_API_KEY.value] - payload = { - "iss": nhs_key, - "sub": nhs_key, - "aud": nhs_oauth_endpoint, - "jti": str(uuid.uuid4()), - "exp": int(time.time()) + 300, - } - return jwt.encode(payload, pds_key, algorithm="RS512", headers={"kid": kid}) - - def request_new_access_token(self, jwt_token, nhs_oauth_endpoint): - access_token_headers = {"content-type": "application/x-www-form-urlencoded"} - access_token_data = { - "grant_type": "client_credentials", - "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "client_assertion": jwt_token, - } - return requests.post( - url=nhs_oauth_endpoint, headers=access_token_headers, data=access_token_data - ) + return ssm_response diff --git a/lambdas/tests/unit/helpers/mock_services.py b/lambdas/tests/unit/helpers/mock_services.py index 3c5c13f42..55e78f39f 100644 --- a/lambdas/tests/unit/helpers/mock_services.py +++ b/lambdas/tests/unit/helpers/mock_services.py @@ -5,6 +5,9 @@ def __init__(self, *arg, **kwargs): def get_ssm_parameters(self, parameters_keys, *arg, **kwargs): return {parameter: f"test_value_{parameter}" for parameter in parameters_keys} + def get_ssm_parameter(self, parameters_keys, *arg, **kwargs): + return f"test_value_{parameters_keys[0]}" + def update_ssm_parameter(self, *arg, **kwargs): pass diff --git a/lambdas/tests/unit/services/base/test_nhs_oauth_service.py b/lambdas/tests/unit/services/base/test_nhs_oauth_service.py new file mode 100644 index 000000000..03deefda4 --- /dev/null +++ b/lambdas/tests/unit/services/base/test_nhs_oauth_service.py @@ -0,0 +1,275 @@ +import json + +import pytest +from enums.pds_ssm_parameters import SSMParameter +from requests import Response +from services.base.nhs_oauth_service import NhsOauthService +from tests.unit.helpers.data.pds.access_token_response import RESPONSE_TOKEN +from tests.unit.helpers.mock_services import FakeSSMService +from utils.exceptions import OAuthErrorException + +fake_ssm_service = FakeSSMService() +nhs_oauth_service = NhsOauthService(fake_ssm_service) + + +def mock_pds_token_response_issued_at(timestamp_in_sec: float) -> dict: + response_token = { + "access_token": "Sr5PGv19wTEHJdDr2wx2f7IGd0cw", + "expires_in": "599", + "token_type": "Bearer", + "issued_at": str(int(timestamp_in_sec * 1000)), + } + + return response_token + + +def test_request_new_token_is_call_with_correct_data(mocker): + mock_jwt_token = "testtest" + mock_endpoint = "api.endpoint/mock" + access_token_headers = {"content-type": "application/x-www-form-urlencoded"} + access_token_data = { + "grant_type": "client_credentials", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_assertion": mock_jwt_token, + } + mock_post = mocker.patch("requests.post") + nhs_oauth_service.request_new_access_token(mock_jwt_token, mock_endpoint) + mock_post.assert_called_with( + url=mock_endpoint, headers=access_token_headers, data=access_token_data + ) + + +def test_create_jwt_for_new_access_token(mocker): + access_token_parameters = { + SSMParameter.NHS_OAUTH_ENDPOINT.value: "api.endpoint/mock", + SSMParameter.PDS_KID.value: "test_string_pds_kid", + SSMParameter.NHS_OAUTH_KEY.value: "test_string_key_oauth", + SSMParameter.PDS_API_KEY.value: "test_string_key_pds", + } + expected_payload = { + "iss": "test_string_key_oauth", + "sub": "test_string_key_oauth", + "aud": "api.endpoint/mock", + "jti": "123412342", + "exp": 1534, + } + mocker.patch("time.time", return_value=1234.1) + mocker.patch("uuid.uuid4", return_value="123412342") + + mock_jwt_encode = mocker.patch("jwt.encode") + nhs_oauth_service.create_jwt_token_for_new_access_token_request( + access_token_parameters + ) + mock_jwt_encode.assert_called_with( + expected_payload, + "test_string_key_pds", + algorithm="RS512", + headers={"kid": "test_string_pds_kid"}, + ) + + +def test_get_current_access_token(): + ssm_parameters_expected = f"test_value_{SSMParameter.PDS_API_ACCESS_TOKEN.value}" + + actual = nhs_oauth_service.get_current_access_token() + assert ssm_parameters_expected == actual + + +def test_update_access_token_ssm(mocker): + fake_ssm_service.update_ssm_parameter = mocker.MagicMock() + + nhs_oauth_service.update_access_token_ssm("test_string") + + fake_ssm_service.update_ssm_parameter.assert_called_with( + parameter_key=SSMParameter.PDS_API_ACCESS_TOKEN.value, + parameter_value="test_string", + parameter_type="SecureString", + ) + + +def test_get_parameters_for_new_access_token(mocker): + parameters = [ + SSMParameter.NHS_OAUTH_ENDPOINT.value, + SSMParameter.PDS_KID.value, + SSMParameter.NHS_OAUTH_KEY.value, + SSMParameter.PDS_API_KEY.value, + ] + fake_ssm_service.get_ssm_parameters = mocker.MagicMock() + nhs_oauth_service.get_parameters_for_new_access_token() + fake_ssm_service.get_ssm_parameters.assert_called_with( + parameters, with_decryption=True + ) + + +def test_get_new_access_token_raise_OAuthErrorException(mocker): + with pytest.raises(OAuthErrorException): + response = Response() + response.status_code = 400 + mock_nhs_oauth_endpoint = "api.test/endpoint" + mock_token = "test_token" + mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_parameters_for_new_access_token", + return_value={ + SSMParameter.NHS_OAUTH_ENDPOINT.value: mock_nhs_oauth_endpoint + }, + ) + mock_create_jwt = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.create_jwt_token_for_new_access_token_request", + return_value=mock_token, + ) + mock_api_call_oauth = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.request_new_access_token", + return_value=response, + ) + mock_update_ssm = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.update_access_token_ssm" + ) + + nhs_oauth_service.get_new_access_token() + + mock_create_jwt.assert_called_with( + {SSMParameter.NHS_OAUTH_ENDPOINT.value: mock_nhs_oauth_endpoint} + ) + mock_api_call_oauth.assert_called_with(mock_token, mock_nhs_oauth_endpoint) + mock_update_ssm.assert_not_called() + + +def test_get_new_access_token_return_200(mocker): + response = Response() + response.status_code = 200 + response._content = json.dumps(RESPONSE_TOKEN).encode("utf-8") + mock_nhs_oauth_endpoint = "api.test/endpoint" + mock_token = "test_token" + mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_parameters_for_new_access_token", + return_value={SSMParameter.NHS_OAUTH_ENDPOINT.value: mock_nhs_oauth_endpoint}, + ) + mock_create_jwt = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.create_jwt_token_for_new_access_token_request", + return_value=mock_token, + ) + mock_api_call_oauth = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.request_new_access_token", + return_value=response, + ) + mock_update_ssm = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.update_access_token_ssm" + ) + expected = RESPONSE_TOKEN["access_token"] + + actual = nhs_oauth_service.get_new_access_token() + + mock_create_jwt.assert_called_with( + {SSMParameter.NHS_OAUTH_ENDPOINT.value: mock_nhs_oauth_endpoint} + ) + mock_api_call_oauth.assert_called_with(mock_token, mock_nhs_oauth_endpoint) + mock_update_ssm.assert_called_with(json.dumps(RESPONSE_TOKEN)) + assert expected == actual + + +def test_pds_request_not_refresh_token_if_more_than_10_seconds_before_expiry(mocker): + time_now = 1600000000 + mocker.patch("time.time", return_value=time_now) + mock_response_token = mock_pds_token_response_issued_at(time_now - 599 + 11) + + mock_get_parameters = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_current_access_token", + return_value=json.dumps(mock_response_token), + ) + mock_new_access_token = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_new_access_token" + ) + mocker.patch("uuid.uuid4", return_value="123412342") + + nhs_oauth_service.create_access_token() + + mock_get_parameters.assert_called_once() + mock_new_access_token.assert_not_called() + + +def test_pds_request_refresh_token_9_seconds_before_expiration( + mocker, +): + time_now = 1600000000 + mocker.patch("time.time", return_value=time_now) + mock_response_token = mock_pds_token_response_issued_at(time_now - 599 + 9) + new_mock_access_token = "mock_access_token" + + mock_get_parameters = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_current_access_token", + return_value=json.dumps(mock_response_token), + ) + mock_new_access_token = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_new_access_token", + return_value=new_mock_access_token, + ) + mocker.patch("uuid.uuid4", return_value="123412342") + + nhs_oauth_service.create_access_token() + + mock_get_parameters.assert_called_once() + mock_new_access_token.assert_called_once() + + +def test_pds_request_refresh_token_if_already_expired(mocker): + time_now = 1600000000 + mocker.patch("time.time", return_value=time_now) + mock_response_token = mock_pds_token_response_issued_at(time_now - 599) + new_mock_access_token = "mock_access_token" + + mock_get_parameters = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_current_access_token", + return_value=json.dumps(mock_response_token), + ) + mock_new_access_token = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_new_access_token", + return_value=new_mock_access_token, + ) + mocker.patch("uuid.uuid4", return_value="123412342") + + nhs_oauth_service.create_access_token() + + mock_get_parameters.assert_called_once() + mock_new_access_token.assert_called_once() + + +def test_pds_request_refresh_token_if_already_expired_11_seconds_ago(mocker): + time_now = 1600000000 + mocker.patch("time.time", return_value=time_now) + mock_response_token = mock_pds_token_response_issued_at(time_now - 599 - 11) + new_mock_access_token = "mock_access_token" + + mock_get_parameters = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_current_access_token", + return_value=json.dumps(mock_response_token), + ) + mock_new_access_token = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_new_access_token", + return_value=new_mock_access_token, + ) + mocker.patch("uuid.uuid4", return_value="123412342") + + nhs_oauth_service.create_access_token() + + mock_get_parameters.assert_called_once() + mock_new_access_token.assert_called_once() + + +def test_pds_request_expired_token(mocker): + new_mock_access_token = "mock_access_token" + + mock_get_parameters = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_current_access_token", + return_value=json.dumps(RESPONSE_TOKEN), + ) + mocker.patch("time.time", return_value=1700000000.953031) + mock_new_access_token = mocker.patch( + "services.base.nhs_oauth_service.NhsOauthService.get_new_access_token", + return_value=new_mock_access_token, + ) + mocker.patch("uuid.uuid4", return_value="123412342") + + actual = nhs_oauth_service.create_access_token() + assert actual == new_mock_access_token + mock_get_parameters.assert_called_once() + mock_new_access_token.assert_called_once() diff --git a/lambdas/tests/unit/services/test_pds_api_service.py b/lambdas/tests/unit/services/test_pds_api_service.py index 213eb1789..699ad6cde 100644 --- a/lambdas/tests/unit/services/test_pds_api_service.py +++ b/lambdas/tests/unit/services/test_pds_api_service.py @@ -5,21 +5,11 @@ from enums.pds_ssm_parameters import SSMParameter from requests import Response from services.pds_api_service import PdsApiService -from tests.unit.helpers.data.pds.access_token_response import RESPONSE_TOKEN from tests.unit.helpers.data.pds.pds_patient_response import PDS_PATIENT +from tests.unit.helpers.mock_services import FakeSSMService from utils.exceptions import PdsErrorException - -class FakeSSMService: - def __init__(self, *arg, **kwargs): - pass - - def get_ssm_parameters(self, parameters_keys, *arg, **kwargs): - return {parameter: f"test_value_{parameter}" for parameter in parameters_keys} - - def update_ssm_parameter(self, *arg, **kwargs): - pass - +ACCESS_TOKEN = "Sr5PGv19wTEHJdDr2wx2f7IGd0cw" fake_ssm_service = FakeSSMService() pds_service = PdsApiService(fake_ssm_service) @@ -37,336 +27,45 @@ def mock_get_patient_data(mocker): yield mock_session -def test_request_new_token_is_call_with_correct_data(mocker): - mock_jwt_token = "testtest" - mock_endpoint = "api.endpoint/mock" - access_token_headers = {"content-type": "application/x-www-form-urlencoded"} - access_token_data = { - "grant_type": "client_credentials", - "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "client_assertion": mock_jwt_token, - } - mock_post = mocker.patch("requests.post") - pds_service.request_new_access_token(mock_jwt_token, mock_endpoint) - mock_post.assert_called_with( - url=mock_endpoint, headers=access_token_headers, data=access_token_data - ) - - -def test_create_jwt_for_new_access_token(mocker): - access_token_parameters = { - SSMParameter.NHS_OAUTH_ENDPOINT.value: "api.endpoint/mock", - SSMParameter.PDS_KID.value: "test_string_pds_kid", - SSMParameter.NHS_OAUTH_KEY.value: "test_string_key_oauth", - SSMParameter.PDS_API_KEY.value: "test_string_key_pds", - } - expected_payload = { - "iss": "test_string_key_oauth", - "sub": "test_string_key_oauth", - "aud": "api.endpoint/mock", - "jti": "123412342", - "exp": 1534, - } - mocker.patch("time.time", return_value=1234.1) - mocker.patch("uuid.uuid4", return_value="123412342") - - mock_jwt_encode = mocker.patch("jwt.encode") - pds_service.create_jwt_token_for_new_access_token_request(access_token_parameters) - mock_jwt_encode.assert_called_with( - expected_payload, - "test_string_key_pds", - algorithm="RS512", - headers={"kid": "test_string_pds_kid"}, - ) - - def test_get_parameters_for_pds_api_request(): - ssm_parameters_expected = ( - f"test_value_{SSMParameter.PDS_API_ENDPOINT.value}", - f"test_value_{SSMParameter.PDS_API_ACCESS_TOKEN.value}", - ) - actual = pds_service.get_parameters_for_pds_api_request() - assert ssm_parameters_expected == actual - - -def test_update_access_token_ssm(mocker): - fake_ssm_service.update_ssm_parameter = mocker.MagicMock() - - pds_service.update_access_token_ssm("test_string") + ssm_parameters_expected = f"test_value_{SSMParameter.PDS_API_ENDPOINT.value}" - fake_ssm_service.update_ssm_parameter.assert_called_with( - parameter_key=SSMParameter.PDS_API_ACCESS_TOKEN.value, - parameter_value="test_string", - parameter_type="SecureString", - ) - - -def test_get_parameters_for_new_access_token(mocker): - parameters = [ - SSMParameter.NHS_OAUTH_ENDPOINT.value, - SSMParameter.PDS_KID.value, - SSMParameter.NHS_OAUTH_KEY.value, - SSMParameter.PDS_API_KEY.value, - ] - fake_ssm_service.get_ssm_parameters = mocker.MagicMock() - pds_service.get_parameters_for_new_access_token() - fake_ssm_service.get_ssm_parameters.assert_called_with( - parameters, with_decryption=True - ) - - -def test_get_new_access_token_return_200(mocker): - response = Response() - response.status_code = 200 - response._content = json.dumps(RESPONSE_TOKEN).encode("utf-8") - mock_nhs_oauth_endpoint = "api.test/endpoint" - mock_token = "test_token" - mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_new_access_token", - return_value={SSMParameter.NHS_OAUTH_ENDPOINT.value: mock_nhs_oauth_endpoint}, - ) - mock_create_jwt = mocker.patch( - "services.pds_api_service.PdsApiService.create_jwt_token_for_new_access_token_request", - return_value=mock_token, - ) - mock_api_call_oauth = mocker.patch( - "services.pds_api_service.PdsApiService.request_new_access_token", - return_value=response, - ) - mock_update_ssm = mocker.patch( - "services.pds_api_service.PdsApiService.update_access_token_ssm" - ) - expected = RESPONSE_TOKEN["access_token"] - - actual = pds_service.get_new_access_token() - - mock_create_jwt.assert_called_with( - {SSMParameter.NHS_OAUTH_ENDPOINT.value: mock_nhs_oauth_endpoint} - ) - mock_api_call_oauth.assert_called_with(mock_token, mock_nhs_oauth_endpoint) - mock_update_ssm.assert_called_with(json.dumps(RESPONSE_TOKEN)) - assert expected == actual - - -def test_get_new_access_token_raise_PdsErrorException(mocker): - with pytest.raises(PdsErrorException): - response = Response() - response.status_code = 400 - mock_nhs_oauth_endpoint = "api.test/endpoint" - mock_token = "test_token" - mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_new_access_token", - return_value={ - SSMParameter.NHS_OAUTH_ENDPOINT.value: mock_nhs_oauth_endpoint - }, - ) - mock_create_jwt = mocker.patch( - "services.pds_api_service.PdsApiService.create_jwt_token_for_new_access_token_request", - return_value=mock_token, - ) - mock_api_call_oauth = mocker.patch( - "services.pds_api_service.PdsApiService.request_new_access_token", - return_value=response, - ) - mock_update_ssm = mocker.patch( - "services.pds_api_service.PdsApiService.update_access_token_ssm" - ) - - pds_service.get_new_access_token() - - mock_create_jwt.assert_called_with( - {SSMParameter.NHS_OAUTH_ENDPOINT.value: mock_nhs_oauth_endpoint} - ) - mock_api_call_oauth.assert_called_with(mock_token, mock_nhs_oauth_endpoint) - mock_update_ssm.assert_not_called() - - -def mock_pds_token_response_issued_at(timestamp_in_sec: float) -> dict: - response_token = { - "access_token": "Sr5PGv19wTEHJdDr2wx2f7IGd0cw", - "expires_in": "599", - "token_type": "Bearer", - "issued_at": str(int(timestamp_in_sec * 1000)), - } - - return response_token + actual = pds_service.get_endpoint_for_pds_api_request() + assert ssm_parameters_expected == actual def test_pds_request_valid_token(mocker, mock_get_patient_data): time_now = 1600000000 mocker.patch("time.time", return_value=time_now) - mock_response_token = mock_pds_token_response_issued_at(time_now) - - mock_api_request_parameters = ( - "api.test/endpoint/", - json.dumps(mock_response_token), - ) - nhs_number = "1111111111" - mock_url_endpoint = "api.test/endpoint/Patient/" + nhs_number - mock_authorization_header = { - "Authorization": f"Bearer {mock_response_token['access_token']}", - "X-Request-ID": "123412342", - } - - mock_get_parameters = mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_pds_api_request", - return_value=mock_api_request_parameters, - ) - mock_new_access_token = mocker.patch( - "services.pds_api_service.PdsApiService.get_new_access_token" - ) - mocker.patch("uuid.uuid4", return_value="123412342") - - pds_service.pds_request(nhs_number="1111111111", retry_on_expired=True) - - mock_get_parameters.assert_called_once() - mock_new_access_token.assert_not_called() - mock_get_patient_data.get.assert_called_with( - url=mock_url_endpoint, headers=mock_authorization_header - ) - - -def test_pds_request_not_refresh_token_if_more_than_10_seconds_before_expiry( - mocker, mock_get_patient_data -): - time_now = 1600000000 - mocker.patch("time.time", return_value=time_now) - mock_response_token = mock_pds_token_response_issued_at(time_now - 599 + 11) - - mock_api_request_parameters = ( - "api.test/endpoint/", - json.dumps(mock_response_token), - ) - - mock_get_parameters = mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_pds_api_request", - return_value=mock_api_request_parameters, - ) - mock_new_access_token = mocker.patch( - "services.pds_api_service.PdsApiService.get_new_access_token" - ) - mocker.patch("uuid.uuid4", return_value="123412342") - - pds_service.pds_request(nhs_number="1111111111", retry_on_expired=True) - - mock_get_parameters.assert_called_once() - mock_new_access_token.assert_not_called() - - -def test_pds_request_refresh_token_9_seconds_before_expiration( - mocker, mock_get_patient_data -): - time_now = 1600000000 - mocker.patch("time.time", return_value=time_now) - mock_response_token = mock_pds_token_response_issued_at(time_now - 599 + 9) - new_mock_access_token = "mock_access_token" + mock_response_token = ACCESS_TOKEN - mock_api_request_parameters = ( - "api.test/endpoint/", - json.dumps(mock_response_token), - ) nhs_number = "1111111111" mock_url_endpoint = "api.test/endpoint/Patient/" + nhs_number mock_authorization_header = { - "Authorization": f"Bearer {new_mock_access_token}", + "Authorization": f"Bearer {mock_response_token}", "X-Request-ID": "123412342", } mock_get_parameters = mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_pds_api_request", - return_value=mock_api_request_parameters, + "services.pds_api_service.PdsApiService.get_endpoint_for_pds_api_request", + return_value="api.test/endpoint/", ) mock_new_access_token = mocker.patch( - "services.pds_api_service.PdsApiService.get_new_access_token", - return_value=new_mock_access_token, + "services.pds_api_service.PdsApiService.create_access_token", + return_value=ACCESS_TOKEN, ) mocker.patch("uuid.uuid4", return_value="123412342") pds_service.pds_request(nhs_number="1111111111", retry_on_expired=True) mock_get_parameters.assert_called_once() - mock_new_access_token.assert_called_once() - mock_get_patient_data.get.assert_called_with( - url=mock_url_endpoint, headers=mock_authorization_header - ) - - -def test_pds_request_refresh_token_if_already_expired(mocker, mock_get_patient_data): - time_now = 1600000000 - mocker.patch("time.time", return_value=time_now) - mock_response_token = mock_pds_token_response_issued_at(time_now - 599) - new_mock_access_token = "mock_access_token" - - mock_api_request_parameters = ( - "api.test/endpoint/", - json.dumps(mock_response_token), - ) - nhs_number = "1111111111" - mock_url_endpoint = "api.test/endpoint/Patient/" + nhs_number - mock_authorization_header = { - "Authorization": f"Bearer {new_mock_access_token}", - "X-Request-ID": "123412342", - } - - mock_get_parameters = mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_pds_api_request", - return_value=mock_api_request_parameters, - ) - mock_new_access_token = mocker.patch( - "services.pds_api_service.PdsApiService.get_new_access_token", - return_value=new_mock_access_token, - ) - mocker.patch("uuid.uuid4", return_value="123412342") - - pds_service.pds_request(nhs_number="1111111111", retry_on_expired=True) - - mock_get_parameters.assert_called_once() - mock_new_access_token.assert_called_once() - mock_get_patient_data.get.assert_called_with( - url=mock_url_endpoint, headers=mock_authorization_header - ) - - -def test_pds_request_refresh_token_if_already_expired_11_seconds_ago( - mocker, mock_get_patient_data -): - time_now = 1600000000 - mocker.patch("time.time", return_value=time_now) - mock_response_token = mock_pds_token_response_issued_at(time_now - 599 - 11) - new_mock_access_token = "mock_access_token" - - mock_api_request_parameters = ( - "api.test/endpoint/", - json.dumps(mock_response_token), - ) - nhs_number = "1111111111" - mock_url_endpoint = "api.test/endpoint/Patient/" + nhs_number - mock_authorization_header = { - "Authorization": f"Bearer {new_mock_access_token}", - "X-Request-ID": "123412342", - } - - mock_get_parameters = mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_pds_api_request", - return_value=mock_api_request_parameters, - ) - mock_new_access_token = mocker.patch( - "services.pds_api_service.PdsApiService.get_new_access_token", - return_value=new_mock_access_token, - ) - mocker.patch("uuid.uuid4", return_value="123412342") - - pds_service.pds_request(nhs_number="1111111111", retry_on_expired=True) - - mock_get_parameters.assert_called_once() - mock_new_access_token.assert_called_once() + mock_new_access_token.assert_called() mock_get_patient_data.get.assert_called_with( url=mock_url_endpoint, headers=mock_authorization_header ) def test_pds_request_expired_token(mocker, mock_get_patient_data): - mock_api_request_parameters = ("api.test/endpoint/", json.dumps(RESPONSE_TOKEN)) nhs_number = "1111111111" mock_url_endpoint = "api.test/endpoint/Patient/" + nhs_number new_mock_access_token = "mock_access_token" @@ -377,12 +76,12 @@ def test_pds_request_expired_token(mocker, mock_get_patient_data): } mock_get_parameters = mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_pds_api_request", - return_value=mock_api_request_parameters, + "services.pds_api_service.PdsApiService.get_endpoint_for_pds_api_request", + return_value="api.test/endpoint/", ) mocker.patch("time.time", return_value=1700000000.953031) mock_new_access_token = mocker.patch( - "services.pds_api_service.PdsApiService.get_new_access_token", + "services.pds_api_service.PdsApiService.create_access_token", return_value=new_mock_access_token, ) mocker.patch("uuid.uuid4", return_value="123412342") @@ -404,7 +103,6 @@ def test_pds_request_valid_token_expired_response(mocker): second_response = Response() second_response.status_code = 200 second_response._content = json.dumps(PDS_PATIENT).encode("utf-8") - mock_api_request_parameters = ("api.test/endpoint/", json.dumps(RESPONSE_TOKEN)) nhs_number = "1111111111" mock_url_endpoint = "api.test/endpoint/Patient/" + nhs_number new_mock_access_token = "mock_access_token" @@ -415,12 +113,12 @@ def test_pds_request_valid_token_expired_response(mocker): } mock_get_parameters = mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_pds_api_request", - return_value=mock_api_request_parameters, + "services.pds_api_service.PdsApiService.get_endpoint_for_pds_api_request", + return_value="api.test/endpoint/", ) mocker.patch("time.time", side_effect=[1600000000.953031, 1700000000.953031]) mock_new_access_token = mocker.patch( - "services.pds_api_service.PdsApiService.get_new_access_token", + "services.pds_api_service.PdsApiService.create_access_token", return_value=new_mock_access_token, ) mocker.patch("uuid.uuid4", return_value="123412342") @@ -432,8 +130,7 @@ def test_pds_request_valid_token_expired_response(mocker): assert actual == second_response assert mock_get_parameters.call_count == 2 - mock_new_access_token.assert_called_once() - + assert mock_new_access_token.call_count == 2 mock_session.get.assert_called_with( url=mock_url_endpoint, headers=mock_authorization_header ) @@ -442,20 +139,20 @@ def test_pds_request_valid_token_expired_response(mocker): def test_pds_request_valid_token_expired_response_no_retry(mocker): response = Response() response.status_code = 401 - mock_api_request_parameters = ("api.test/endpoint/", json.dumps(RESPONSE_TOKEN)) nhs_number = "1111111111" mock_url_endpoint = "api.test/endpoint/Patient/" + nhs_number mock_authorization_header = { - "Authorization": f"Bearer {RESPONSE_TOKEN['access_token']}", + "Authorization": f"Bearer {ACCESS_TOKEN}", "X-Request-ID": "123412342", } mock_get_parameters = mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_pds_api_request", - return_value=mock_api_request_parameters, + "services.pds_api_service.PdsApiService.get_endpoint_for_pds_api_request", + return_value="api.test/endpoint/", ) mocker.patch("time.time", return_value=1600000000.953031) mock_new_access_token = mocker.patch( - "services.pds_api_service.PdsApiService.get_new_access_token" + "services.pds_api_service.PdsApiService.create_access_token", + return_value=ACCESS_TOKEN, ) mocker.patch("uuid.uuid4", return_value="123412342") @@ -466,7 +163,7 @@ def test_pds_request_valid_token_expired_response_no_retry(mocker): assert actual == response mock_get_parameters.assert_called_once() - mock_new_access_token.assert_not_called() + mock_new_access_token.assert_called_once() mock_session.get.assert_called_with( url=mock_url_endpoint, headers=mock_authorization_header ) @@ -475,7 +172,7 @@ def test_pds_request_valid_token_expired_response_no_retry(mocker): def test_pds_request_raise_pds_error_exception(mocker): with pytest.raises(PdsErrorException): mock_get_parameters = mocker.patch( - "services.pds_api_service.PdsApiService.get_parameters_for_pds_api_request", + "services.pds_api_service.PdsApiService.get_endpoint_for_pds_api_request", side_effect=ClientError( {"Error": {"Code": "500", "Message": "mocked error"}}, "test" ), diff --git a/lambdas/utils/exceptions.py b/lambdas/utils/exceptions.py index 489643191..d96d810d4 100644 --- a/lambdas/utils/exceptions.py +++ b/lambdas/utils/exceptions.py @@ -6,6 +6,10 @@ class InvalidResourceIdException(Exception): pass +class OAuthErrorException(Exception): + pass + + class PdsErrorException(Exception): pass