From 0a968184d02df9dff2e4d40a5a236f7d08edb0e1 Mon Sep 17 00:00:00 2001 From: marrobi Date: Fri, 22 Dec 2023 15:21:02 +0000 Subject: [PATCH] Create method to get creds async without context manager --- api_app/_version.py | 2 +- api_app/api/dependencies/database.py | 4 +- api_app/api/routes/health.py | 2 +- api_app/core/credentials.py | 23 ++++++- api_app/event_grid/helpers.py | 2 +- .../airlock_request_status_update.py | 2 +- .../service_bus/deployment_status_updater.py | 2 +- api_app/service_bus/helpers.py | 2 +- api_app/tests_ma/test_db/test_events.py | 8 +-- .../test_services/test_health_checker.py | 65 +++++++++---------- 10 files changed, 64 insertions(+), 48 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index b0d7306ae4..30e65e30d0 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.18.2" +__version__ = "0.18.3" diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index 2bac410eb2..ab0486dca1 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -4,7 +4,7 @@ from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient from fastapi import HTTPException, status from core.config import MANAGED_IDENTITY_CLIENT_ID, STATE_STORE_ENDPOINT, STATE_STORE_KEY, STATE_STORE_SSL_VERIFY, SUBSCRIPTION_ID, RESOURCE_MANAGER_ENDPOINT, CREDENTIAL_SCOPES, RESOURCE_GROUP_NAME, COSMOSDB_ACCOUNT_NAME -from core.credentials import get_credential +from core.credentials import get_credential_async from db.errors import UnableToAccessDatabase from db.repositories.base import BaseRepository from resources import strings @@ -30,7 +30,7 @@ def __init__(self): async def _connect_to_db(self) -> CosmosClient: logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}") - credential = get_credential() + credential = await get_credential_async() if MANAGED_IDENTITY_CLIENT_ID: logger.debug("Connecting with managed identity") cosmos_client = CosmosClient( diff --git a/api_app/api/routes/health.py b/api_app/api/routes/health.py index 535087b86a..9d0ef42c12 100644 --- a/api_app/api/routes/health.py +++ b/api_app/api/routes/health.py @@ -14,7 +14,7 @@ async def health_check(request: Request) -> HealthCheck: # The health endpoint checks the status of key components of the system. # Note that Resource Processor checks incur Azure management calls, so # calling this endpoint frequently may result in API throttling. - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_cm() as credential: cosmos, sb, rp = await asyncio.gather( create_state_store_status(), create_service_bus_status(credential), diff --git a/api_app/core/credentials.py b/api_app/core/credentials.py index 8780248a63..331a44423c 100644 --- a/api_app/core/credentials.py +++ b/api_app/core/credentials.py @@ -31,8 +31,29 @@ def get_credential() -> TokenCredential: ) -@asynccontextmanager async def get_credential_async() -> TokenCredential: + """ + Context manager which yields the default credentials. + """ + managed_identity = config.MANAGED_IDENTITY_CLIENT_ID + credential = ( + ChainedTokenCredentialASync( + ManagedIdentityCredentialASync(client_id=managed_identity) + ) + if managed_identity + else DefaultAzureCredentialASync(authority=urlparse(config.AAD_AUTHORITY_URL).netloc, + exclude_shared_token_cache_credential=True, + exclude_workload_identity_credential=True, + exclude_developer_cli_credential=True, + exclude_managed_identity_credential=True, + exclude_powershell_credential=True + ) + ) + return credential + + +@asynccontextmanager +async def get_credential_async_cm() -> TokenCredential: """ Context manager which yields the default credentials. """ diff --git a/api_app/event_grid/helpers.py b/api_app/event_grid/helpers.py index ed96ec695c..debfcfdfc3 100644 --- a/api_app/event_grid/helpers.py +++ b/api_app/event_grid/helpers.py @@ -4,7 +4,7 @@ async def publish_event(event: EventGridEvent, topic_endpoint: str): - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_cm() as credential: client = EventGridPublisherClient(topic_endpoint, credential) async with client: await client.send([event]) diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index babb302702..7255f3e136 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -32,7 +32,7 @@ async def receive_messages(self): with tracer.start_as_current_span("airlock_receive_messages"): while True: try: - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_cm() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) receiver = service_bus_client.get_queue_receiver(queue_name=config.SERVICE_BUS_STEP_RESULT_QUEUE) logger.info(f"Looking for new messages on {config.SERVICE_BUS_STEP_RESULT_QUEUE} queue...") diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 255b25d1c3..26ebfe9309 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -41,7 +41,7 @@ async def receive_messages(self): with tracer.start_as_current_span("deployment_status_receive_messages"): while True: try: - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_cm() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) logger.info(f"Looking for new messages on {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue...") diff --git a/api_app/service_bus/helpers.py b/api_app/service_bus/helpers.py index 55ae5c1b20..79f11b687b 100644 --- a/api_app/service_bus/helpers.py +++ b/api_app/service_bus/helpers.py @@ -24,7 +24,7 @@ async def _send_message(message: ServiceBusMessage, queue: str): :param queue: The Service Bus queue to send the message to. :type queue: str """ - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_cm() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) async with service_bus_client: diff --git a/api_app/tests_ma/test_db/test_events.py b/api_app/tests_ma/test_db/test_events.py index 7b93a16a82..016067b105 100644 --- a/api_app/tests_ma/test_db/test_events.py +++ b/api_app/tests_ma/test_db/test_events.py @@ -8,8 +8,8 @@ @patch("db.events.get_credential") @patch("db.events.CosmosDBManagementClient") -async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_mock): - get_credential_async_mock.return_value = AsyncMock() +async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_cm_mock): + get_credential_async_cm_mock.return_value = AsyncMock() cosmos_db_mgmt_client_mock.return_value = MagicMock() result = await events.bootstrap_database() @@ -19,8 +19,8 @@ async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_creden @patch("db.events.get_credential") @patch("db.events.CosmosDBManagementClient") -async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_mock): - get_credential_async_mock.return_value = AsyncMock() +async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_cm_mock): + get_credential_async_cm_mock.return_value = AsyncMock() cosmos_db_mgmt_client_mock.side_effect = AzureError("some error") result = await events.bootstrap_database() diff --git a/api_app/tests_ma/test_services/test_health_checker.py b/api_app/tests_ma/test_services/test_health_checker.py index 61e76b85c4..d918b5c87e 100644 --- a/api_app/tests_ma/test_services/test_health_checker.py +++ b/api_app/tests_ma/test_services/test_health_checker.py @@ -11,84 +11,79 @@ pytestmark = pytest.mark.asyncio -@patch("core.credentials.get_credential_async") @patch("api.dependencies.database.Database._get_store_key") @patch("api.dependencies.database.Database.cosmos_client") -async def test_get_state_store_status_responding(_, get_store_key_mock, get_credential_async) -> None: +async def test_get_state_store_status_responding(_, get_store_key_mock) -> None: get_store_key_mock.return_value = None - status, message = await health_checker.create_state_store_status(get_credential_async) + status, message = await health_checker.create_state_store_status() assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async") @patch("api.dependencies.database.Database._get_store_key") @patch("api.dependencies.database.Database.get_db_client") -async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock) -> None: get_store_key_mock.return_value = None cosmos_client_mock.return_value = None cosmos_client_mock.side_effect = ServiceRequestError(message="some message") - status, message = await health_checker.create_state_store_status(get_credential_async) + status, message = await health_checker.create_state_store_status() assert status == StatusEnum.not_ok assert message == strings.STATE_STORE_ENDPOINT_NOT_RESPONDING -@patch("core.credentials.get_credential_async") @patch("api.dependencies.database.Database._get_store_key") @patch("api.dependencies.database.Database.get_db_client") -async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock) -> None: get_store_key_mock.return_value = None cosmos_client_mock.return_value = None cosmos_client_mock.side_effect = Exception() - status, message = await health_checker.create_state_store_status(get_credential_async) + status, message = await health_checker.create_state_store_status() assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_responding(service_bus_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_service_bus_status_responding(service_bus_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() service_bus_client_mock().get_queue_receiver.__aenter__.return_value = AsyncMock() - status, message = await health_checker.create_service_bus_status(get_credential_async) + status, message = await health_checker.create_service_bus_status(get_credential_async_cm) assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_not_responding(service_bus_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_service_bus_status_not_responding(service_bus_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() service_bus_client_mock.return_value = None service_bus_client_mock.side_effect = ServiceBusConnectionError(message="some message") - status, message = await health_checker.create_service_bus_status(get_credential_async) + status, message = await health_checker.create_service_bus_status(get_credential_async_cm) assert status == StatusEnum.not_ok assert message == strings.SERVICE_BUS_NOT_RESPONDING -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_other_exception(service_bus_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_service_bus_status_other_exception(service_bus_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() service_bus_client_mock.return_value = None service_bus_client_mock.side_effect = Exception() - status, message = await health_checker.create_service_bus_status(get_credential_async) + status, message = await health_checker.create_service_bus_status(get_credential_async_cm) assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ComputeManagementClient") -async def test_get_resource_processor_status_healthy(resource_processor_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_resource_processor_status_healthy(resource_processor_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() resource_processor_client_mock().virtual_machine_scale_set_vms.return_value = AsyncMock() vm_mock = MagicMock() vm_mock.instance_id = 'mocked_id' @@ -100,16 +95,16 @@ async def test_get_resource_processor_status_healthy(resource_processor_client_m awaited_mock.set_result(instance_view_mock) resource_processor_client_mock().virtual_machine_scale_set_vms.get_instance_view.return_value = awaited_mock - status, message = await health_checker.create_resource_processor_status(get_credential_async) + status, message = await health_checker.create_resource_processor_status(get_credential_async_cm) assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ComputeManagementClient", return_value=MagicMock()) -async def test_get_resource_processor_status_not_healthy(resource_processor_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_resource_processor_status_not_healthy(resource_processor_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() resource_processor_client_mock().virtual_machine_scale_set_vms.return_value = AsyncMock() vm_mock = MagicMock() @@ -122,19 +117,19 @@ async def test_get_resource_processor_status_not_healthy(resource_processor_clie awaited_mock.set_result(instance_view_mock) resource_processor_client_mock().virtual_machine_scale_set_vms.get_instance_view.return_value = awaited_mock - status, message = await health_checker.create_resource_processor_status(get_credential_async) + status, message = await health_checker.create_resource_processor_status(get_credential_async_cm) assert status == StatusEnum.not_ok assert message == strings.RESOURCE_PROCESSOR_GENERAL_ERROR_MESSAGE -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ComputeManagementClient") -async def test_get_resource_processor_status_other_exception(resource_processor_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_resource_processor_status_other_exception(resource_processor_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() resource_processor_client_mock.return_value = None resource_processor_client_mock.side_effect = Exception() - status, message = await health_checker.create_resource_processor_status(get_credential_async) + status, message = await health_checker.create_resource_processor_status(get_credential_async_cm) assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR