Skip to content

Commit

Permalink
Create method to get creds async without context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
marrobi committed Dec 22, 2023
1 parent 1fd7fa7 commit 0a96818
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 48 deletions.
2 changes: 1 addition & 1 deletion api_app/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.18.2"
__version__ = "0.18.3"
4 changes: 2 additions & 2 deletions api_app/api/dependencies/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient
from fastapi import HTTPException, status
from core.config import MANAGED_IDENTITY_CLIENT_ID, STATE_STORE_ENDPOINT, STATE_STORE_KEY, STATE_STORE_SSL_VERIFY, SUBSCRIPTION_ID, RESOURCE_MANAGER_ENDPOINT, CREDENTIAL_SCOPES, RESOURCE_GROUP_NAME, COSMOSDB_ACCOUNT_NAME
from core.credentials import get_credential
from core.credentials import get_credential_async
from db.errors import UnableToAccessDatabase
from db.repositories.base import BaseRepository
from resources import strings
Expand All @@ -30,7 +30,7 @@ def __init__(self):
async def _connect_to_db(self) -> CosmosClient:
logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}")

credential = get_credential()
credential = await get_credential_async()
if MANAGED_IDENTITY_CLIENT_ID:
logger.debug("Connecting with managed identity")
cosmos_client = CosmosClient(
Expand Down
2 changes: 1 addition & 1 deletion api_app/api/routes/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def health_check(request: Request) -> HealthCheck:
# The health endpoint checks the status of key components of the system.
# Note that Resource Processor checks incur Azure management calls, so
# calling this endpoint frequently may result in API throttling.
async with credentials.get_credential_async() as credential:
async with credentials.get_credential_async_cm() as credential:
cosmos, sb, rp = await asyncio.gather(
create_state_store_status(),
create_service_bus_status(credential),
Expand Down
23 changes: 22 additions & 1 deletion api_app/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,29 @@ def get_credential() -> TokenCredential:
)


@asynccontextmanager
async def get_credential_async() -> TokenCredential:
"""
Context manager which yields the default credentials.
"""
managed_identity = config.MANAGED_IDENTITY_CLIENT_ID
credential = (
ChainedTokenCredentialASync(
ManagedIdentityCredentialASync(client_id=managed_identity)
)
if managed_identity
else DefaultAzureCredentialASync(authority=urlparse(config.AAD_AUTHORITY_URL).netloc,
exclude_shared_token_cache_credential=True,
exclude_workload_identity_credential=True,
exclude_developer_cli_credential=True,
exclude_managed_identity_credential=True,
exclude_powershell_credential=True
)
)
return credential


@asynccontextmanager
async def get_credential_async_cm() -> TokenCredential:
"""
Context manager which yields the default credentials.
"""
Expand Down
2 changes: 1 addition & 1 deletion api_app/event_grid/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


async def publish_event(event: EventGridEvent, topic_endpoint: str):
async with credentials.get_credential_async() as credential:
async with credentials.get_credential_async_cm() as credential:
client = EventGridPublisherClient(topic_endpoint, credential)
async with client:
await client.send([event])
2 changes: 1 addition & 1 deletion api_app/service_bus/airlock_request_status_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def receive_messages(self):
with tracer.start_as_current_span("airlock_receive_messages"):
while True:
try:
async with credentials.get_credential_async() as credential:
async with credentials.get_credential_async_cm() as credential:
service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential)
receiver = service_bus_client.get_queue_receiver(queue_name=config.SERVICE_BUS_STEP_RESULT_QUEUE)
logger.info(f"Looking for new messages on {config.SERVICE_BUS_STEP_RESULT_QUEUE} queue...")
Expand Down
2 changes: 1 addition & 1 deletion api_app/service_bus/deployment_status_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def receive_messages(self):
with tracer.start_as_current_span("deployment_status_receive_messages"):
while True:
try:
async with credentials.get_credential_async() as credential:
async with credentials.get_credential_async_cm() as credential:
service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential)

logger.info(f"Looking for new messages on {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue...")
Expand Down
2 changes: 1 addition & 1 deletion api_app/service_bus/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def _send_message(message: ServiceBusMessage, queue: str):
:param queue: The Service Bus queue to send the message to.
:type queue: str
"""
async with credentials.get_credential_async() as credential:
async with credentials.get_credential_async_cm() as credential:
service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential)

async with service_bus_client:
Expand Down
8 changes: 4 additions & 4 deletions api_app/tests_ma/test_db/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

@patch("db.events.get_credential")
@patch("db.events.CosmosDBManagementClient")
async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_mock):
get_credential_async_mock.return_value = AsyncMock()
async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_cm_mock):
get_credential_async_cm_mock.return_value = AsyncMock()
cosmos_db_mgmt_client_mock.return_value = MagicMock()

result = await events.bootstrap_database()
Expand All @@ -19,8 +19,8 @@ async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_creden

@patch("db.events.get_credential")
@patch("db.events.CosmosDBManagementClient")
async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_mock):
get_credential_async_mock.return_value = AsyncMock()
async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_cm_mock):
get_credential_async_cm_mock.return_value = AsyncMock()
cosmos_db_mgmt_client_mock.side_effect = AzureError("some error")

result = await events.bootstrap_database()
Expand Down
65 changes: 30 additions & 35 deletions api_app/tests_ma/test_services/test_health_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,84 +11,79 @@
pytestmark = pytest.mark.asyncio


@patch("core.credentials.get_credential_async")
@patch("api.dependencies.database.Database._get_store_key")
@patch("api.dependencies.database.Database.cosmos_client")
async def test_get_state_store_status_responding(_, get_store_key_mock, get_credential_async) -> None:
async def test_get_state_store_status_responding(_, get_store_key_mock) -> None:
get_store_key_mock.return_value = None
status, message = await health_checker.create_state_store_status(get_credential_async)
status, message = await health_checker.create_state_store_status()

assert status == StatusEnum.ok
assert message == ""


@patch("core.credentials.get_credential_async")
@patch("api.dependencies.database.Database._get_store_key")
@patch("api.dependencies.database.Database.get_db_client")
async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None:
get_credential_async.return_value = AsyncMock()
async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock) -> None:
get_store_key_mock.return_value = None
cosmos_client_mock.return_value = None
cosmos_client_mock.side_effect = ServiceRequestError(message="some message")
status, message = await health_checker.create_state_store_status(get_credential_async)
status, message = await health_checker.create_state_store_status()

assert status == StatusEnum.not_ok
assert message == strings.STATE_STORE_ENDPOINT_NOT_RESPONDING


@patch("core.credentials.get_credential_async")
@patch("api.dependencies.database.Database._get_store_key")
@patch("api.dependencies.database.Database.get_db_client")
async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None:
get_credential_async.return_value = AsyncMock()
async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock) -> None:
get_store_key_mock.return_value = None
cosmos_client_mock.return_value = None
cosmos_client_mock.side_effect = Exception()
status, message = await health_checker.create_state_store_status(get_credential_async)
status, message = await health_checker.create_state_store_status()

assert status == StatusEnum.not_ok
assert message == strings.UNSPECIFIED_ERROR


@patch("core.credentials.get_credential_async")
@patch("core.credentials.get_credential_async_cm")
@patch("services.health_checker.ServiceBusClient")
async def test_get_service_bus_status_responding(service_bus_client_mock, get_credential_async) -> None:
get_credential_async.return_value = AsyncMock()
async def test_get_service_bus_status_responding(service_bus_client_mock, get_credential_async_cm) -> None:
get_credential_async_cm.return_value = AsyncMock()
service_bus_client_mock().get_queue_receiver.__aenter__.return_value = AsyncMock()
status, message = await health_checker.create_service_bus_status(get_credential_async)
status, message = await health_checker.create_service_bus_status(get_credential_async_cm)

assert status == StatusEnum.ok
assert message == ""


@patch("core.credentials.get_credential_async")
@patch("core.credentials.get_credential_async_cm")
@patch("services.health_checker.ServiceBusClient")
async def test_get_service_bus_status_not_responding(service_bus_client_mock, get_credential_async) -> None:
get_credential_async.return_value = AsyncMock()
async def test_get_service_bus_status_not_responding(service_bus_client_mock, get_credential_async_cm) -> None:
get_credential_async_cm.return_value = AsyncMock()
service_bus_client_mock.return_value = None
service_bus_client_mock.side_effect = ServiceBusConnectionError(message="some message")
status, message = await health_checker.create_service_bus_status(get_credential_async)
status, message = await health_checker.create_service_bus_status(get_credential_async_cm)

assert status == StatusEnum.not_ok
assert message == strings.SERVICE_BUS_NOT_RESPONDING


@patch("core.credentials.get_credential_async")
@patch("core.credentials.get_credential_async_cm")
@patch("services.health_checker.ServiceBusClient")
async def test_get_service_bus_status_other_exception(service_bus_client_mock, get_credential_async) -> None:
get_credential_async.return_value = AsyncMock()
async def test_get_service_bus_status_other_exception(service_bus_client_mock, get_credential_async_cm) -> None:
get_credential_async_cm.return_value = AsyncMock()
service_bus_client_mock.return_value = None
service_bus_client_mock.side_effect = Exception()
status, message = await health_checker.create_service_bus_status(get_credential_async)
status, message = await health_checker.create_service_bus_status(get_credential_async_cm)

assert status == StatusEnum.not_ok
assert message == strings.UNSPECIFIED_ERROR


@patch("core.credentials.get_credential_async")
@patch("core.credentials.get_credential_async_cm")
@patch("services.health_checker.ComputeManagementClient")
async def test_get_resource_processor_status_healthy(resource_processor_client_mock, get_credential_async) -> None:
get_credential_async.return_value = AsyncMock()
async def test_get_resource_processor_status_healthy(resource_processor_client_mock, get_credential_async_cm) -> None:
get_credential_async_cm.return_value = AsyncMock()
resource_processor_client_mock().virtual_machine_scale_set_vms.return_value = AsyncMock()
vm_mock = MagicMock()
vm_mock.instance_id = 'mocked_id'
Expand All @@ -100,16 +95,16 @@ async def test_get_resource_processor_status_healthy(resource_processor_client_m
awaited_mock.set_result(instance_view_mock)
resource_processor_client_mock().virtual_machine_scale_set_vms.get_instance_view.return_value = awaited_mock

status, message = await health_checker.create_resource_processor_status(get_credential_async)
status, message = await health_checker.create_resource_processor_status(get_credential_async_cm)

assert status == StatusEnum.ok
assert message == ""


@patch("core.credentials.get_credential_async")
@patch("core.credentials.get_credential_async_cm")
@patch("services.health_checker.ComputeManagementClient", return_value=MagicMock())
async def test_get_resource_processor_status_not_healthy(resource_processor_client_mock, get_credential_async) -> None:
get_credential_async.return_value = AsyncMock()
async def test_get_resource_processor_status_not_healthy(resource_processor_client_mock, get_credential_async_cm) -> None:
get_credential_async_cm.return_value = AsyncMock()

resource_processor_client_mock().virtual_machine_scale_set_vms.return_value = AsyncMock()
vm_mock = MagicMock()
Expand All @@ -122,19 +117,19 @@ async def test_get_resource_processor_status_not_healthy(resource_processor_clie
awaited_mock.set_result(instance_view_mock)
resource_processor_client_mock().virtual_machine_scale_set_vms.get_instance_view.return_value = awaited_mock

status, message = await health_checker.create_resource_processor_status(get_credential_async)
status, message = await health_checker.create_resource_processor_status(get_credential_async_cm)

assert status == StatusEnum.not_ok
assert message == strings.RESOURCE_PROCESSOR_GENERAL_ERROR_MESSAGE


@patch("core.credentials.get_credential_async")
@patch("core.credentials.get_credential_async_cm")
@patch("services.health_checker.ComputeManagementClient")
async def test_get_resource_processor_status_other_exception(resource_processor_client_mock, get_credential_async) -> None:
get_credential_async.return_value = AsyncMock()
async def test_get_resource_processor_status_other_exception(resource_processor_client_mock, get_credential_async_cm) -> None:
get_credential_async_cm.return_value = AsyncMock()
resource_processor_client_mock.return_value = None
resource_processor_client_mock.side_effect = Exception()
status, message = await health_checker.create_resource_processor_status(get_credential_async)
status, message = await health_checker.create_resource_processor_status(get_credential_async_cm)

assert status == StatusEnum.not_ok
assert message == strings.UNSPECIFIED_ERROR
Expand Down

0 comments on commit 0a96818

Please sign in to comment.