Skip to content

Commit

Permalink
[prmp-1120] new nrl service
Browse files Browse the repository at this point in the history
  • Loading branch information
NogaNHS committed Oct 30, 2024
1 parent 8d1dc7d commit 172a6bb
Show file tree
Hide file tree
Showing 7 changed files with 456 additions and 418 deletions.
110 changes: 110 additions & 0 deletions lambdas/services/base/nhs_oauth_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import json
import time
import uuid

import jwt
import requests
from enums.pds_ssm_parameters import SSMParameter
from requests.exceptions import HTTPError
from utils.audit_logging_setup import LoggingService
from utils.exceptions import OAuthErrorException

logger = LoggingService(__name__)


class NhsOauthService:
def __init__(self, ssm_service):
self.ssm_service = ssm_service

def create_access_token(self):
access_token_response = self.get_current_access_token()
access_token_response = json.loads(access_token_response)
access_token = access_token_response["access_token"]
access_token_expiration = (
int(access_token_response["expires_in"])
+ int(access_token_response["issued_at"]) / 1000
)
time_safety_margin_seconds = 10
remaining_time_before_expiration = access_token_expiration - time.time()
if remaining_time_before_expiration < time_safety_margin_seconds:
access_token = self.get_new_access_token()

return access_token

def get_new_access_token(self):
logger.info("Getting new OAuth access token")
try:
access_token_ssm_parameter = self.get_parameters_for_new_access_token()
jwt_token = self.create_jwt_token_for_new_access_token_request(
access_token_ssm_parameter
)
nhs_oauth_endpoint = access_token_ssm_parameter[
SSMParameter.NHS_OAUTH_ENDPOINT.value
]
nhs_oauth_response = self.request_new_access_token(
jwt_token, nhs_oauth_endpoint
)
nhs_oauth_response.raise_for_status()
token_access_response = nhs_oauth_response.json()
self.update_access_token_ssm(json.dumps(token_access_response))
except HTTPError as e:
logger.error(
e.response, {"Result": "Issue while creating new access token"}
)
raise OAuthErrorException("Error creating oauth access token")
return token_access_response["access_token"]

def get_parameters_for_new_access_token(self):
parameters = [
SSMParameter.NHS_OAUTH_ENDPOINT.value,
SSMParameter.PDS_KID.value,
SSMParameter.NHS_OAUTH_KEY.value,
SSMParameter.PDS_API_KEY.value,
]
return self.ssm_service.get_ssm_parameters(parameters, with_decryption=True)

def update_access_token_ssm(self, parameter_value: str):
parameter_key = SSMParameter.PDS_API_ACCESS_TOKEN.value
self.ssm_service.update_ssm_parameter(
parameter_key=parameter_key,
parameter_value=parameter_value,
parameter_type="SecureString",
)

def get_current_access_token(self):
parameters = [
SSMParameter.PDS_API_ACCESS_TOKEN.value,
]
ssm_response = self.ssm_service.get_ssm_parameter(
parameters_keys=parameters, with_decryption=True
)
return ssm_response

def create_jwt_token_for_new_access_token_request(
self, access_token_ssm_parameters
):
nhs_oauth_endpoint = access_token_ssm_parameters[
SSMParameter.NHS_OAUTH_ENDPOINT.value
]
kid = access_token_ssm_parameters[SSMParameter.PDS_KID.value]
nhs_key = access_token_ssm_parameters[SSMParameter.NHS_OAUTH_KEY.value]
pds_key = access_token_ssm_parameters[SSMParameter.PDS_API_KEY.value]
payload = {
"iss": nhs_key,
"sub": nhs_key,
"aud": nhs_oauth_endpoint,
"jti": str(uuid.uuid4()),
"exp": int(time.time()) + 300,
}
return jwt.encode(payload, pds_key, algorithm="RS512", headers={"kid": kid})

def request_new_access_token(self, jwt_token, nhs_oauth_endpoint):
access_token_headers = {"content-type": "application/x-www-form-urlencoded"}
access_token_data = {
"grant_type": "client_credentials",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": jwt_token,
}
return requests.post(
url=nhs_oauth_endpoint, headers=access_token_headers, data=access_token_data
)
30 changes: 30 additions & 0 deletions lambdas/services/nrl_api_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import uuid

from services.base.nhs_oauth_service import NhsOauthService
from utils.audit_logging_setup import LoggingService

logger = LoggingService(__name__)


class NrlApiService(NhsOauthService):
def __init__(self, ssm_service):
super().__init__(ssm_service)
self.headers = {
"Authorization": f"Bearer {self.create_access_token()}",
"Accept": "application/json",
}

def get_api_endpoint(self):
pass

def create_new_pointer(self, body, headers):
self.set_x_request_id()

def update_pointer(self):
self.set_x_request_id()

def delete_pointer(self):
self.set_x_request_id()

def set_x_request_id(self):
self.headers["X-Request-ID"] = uuid.uuid4()
97 changes: 8 additions & 89 deletions lambdas/services/pds_api_service.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import json
import time
import uuid
from json import JSONDecodeError

import jwt
import requests
from botocore.exceptions import ClientError
from enums.pds_ssm_parameters import SSMParameter
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, HTTPError, Timeout
from services.base.nhs_oauth_service import NhsOauthService
from services.patient_search_service import PatientSearch
from urllib3 import Retry
from utils.audit_logging_setup import LoggingService
Expand All @@ -17,9 +15,9 @@
logger = LoggingService(__name__)


class PdsApiService(PatientSearch):
class PdsApiService(PatientSearch, NhsOauthService):
def __init__(self, ssm_service):
self.ssm_service = ssm_service
super().__init__(ssm_service)

retry_strategy = Retry(
total=3,
Expand All @@ -34,17 +32,8 @@ def __init__(self, ssm_service):

def pds_request(self, nhs_number: str, retry_on_expired: bool):
try:
endpoint, access_token_response = self.get_parameters_for_pds_api_request()
access_token_response = json.loads(access_token_response)
access_token = access_token_response["access_token"]
access_token_expiration = (
int(access_token_response["expires_in"])
+ int(access_token_response["issued_at"]) / 1000
)
time_safety_margin_seconds = 10
remaining_time_before_expiration = access_token_expiration - time.time()
if remaining_time_before_expiration < time_safety_margin_seconds:
access_token = self.get_new_access_token()
endpoint = self.get_endpoint_for_pds_api_request()
access_token = self.create_access_token()

x_request_id = str(uuid.uuid4())

Expand All @@ -71,81 +60,11 @@ def pds_request(self, nhs_number: str, retry_on_expired: bool):
logger.error(str(e), {"Result": "Error when calling PDS"})
raise PdsTooManyRequestsException("Failed to perform patient search")

def get_new_access_token(self):
logger.info("Getting new PDS access token")
try:
access_token_ssm_parameter = self.get_parameters_for_new_access_token()
jwt_token = self.create_jwt_token_for_new_access_token_request(
access_token_ssm_parameter
)
nhs_oauth_endpoint = access_token_ssm_parameter[
SSMParameter.NHS_OAUTH_ENDPOINT.value
]
nhs_oauth_response = self.request_new_access_token(
jwt_token, nhs_oauth_endpoint
)
nhs_oauth_response.raise_for_status()
token_access_response = nhs_oauth_response.json()
self.update_access_token_ssm(json.dumps(token_access_response))
except HTTPError as e:
logger.error(
e.response, {"Result": "Issue while creating new access token"}
)
raise PdsErrorException("Error accessing PDS API")
return token_access_response["access_token"]

def get_parameters_for_new_access_token(self):
parameters = [
SSMParameter.NHS_OAUTH_ENDPOINT.value,
SSMParameter.PDS_KID.value,
SSMParameter.NHS_OAUTH_KEY.value,
SSMParameter.PDS_API_KEY.value,
]
return self.ssm_service.get_ssm_parameters(parameters, with_decryption=True)

def update_access_token_ssm(self, parameter_value: str):
parameter_key = SSMParameter.PDS_API_ACCESS_TOKEN.value
self.ssm_service.update_ssm_parameter(
parameter_key=parameter_key,
parameter_value=parameter_value,
parameter_type="SecureString",
)

def get_parameters_for_pds_api_request(self):
def get_endpoint_for_pds_api_request(self):
parameters = [
SSMParameter.PDS_API_ENDPOINT.value,
SSMParameter.PDS_API_ACCESS_TOKEN.value,
]
ssm_response = self.ssm_service.get_ssm_parameters(
ssm_response = self.ssm_service.get_ssm_parameter(
parameters_keys=parameters, with_decryption=True
)
return ssm_response[parameters[0]], ssm_response[parameters[1]]

def create_jwt_token_for_new_access_token_request(
self, access_token_ssm_parameters
):
nhs_oauth_endpoint = access_token_ssm_parameters[
SSMParameter.NHS_OAUTH_ENDPOINT.value
]
kid = access_token_ssm_parameters[SSMParameter.PDS_KID.value]
nhs_key = access_token_ssm_parameters[SSMParameter.NHS_OAUTH_KEY.value]
pds_key = access_token_ssm_parameters[SSMParameter.PDS_API_KEY.value]
payload = {
"iss": nhs_key,
"sub": nhs_key,
"aud": nhs_oauth_endpoint,
"jti": str(uuid.uuid4()),
"exp": int(time.time()) + 300,
}
return jwt.encode(payload, pds_key, algorithm="RS512", headers={"kid": kid})

def request_new_access_token(self, jwt_token, nhs_oauth_endpoint):
access_token_headers = {"content-type": "application/x-www-form-urlencoded"}
access_token_data = {
"grant_type": "client_credentials",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": jwt_token,
}
return requests.post(
url=nhs_oauth_endpoint, headers=access_token_headers, data=access_token_data
)
return ssm_response
3 changes: 3 additions & 0 deletions lambdas/tests/unit/helpers/mock_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ def __init__(self, *arg, **kwargs):
def get_ssm_parameters(self, parameters_keys, *arg, **kwargs):
return {parameter: f"test_value_{parameter}" for parameter in parameters_keys}

def get_ssm_parameter(self, parameters_keys, *arg, **kwargs):
return f"test_value_{parameters_keys[0]}"

def update_ssm_parameter(self, *arg, **kwargs):
pass

Expand Down
Loading

0 comments on commit 172a6bb

Please sign in to comment.