From 5071e4594cbad4a437e7df0f0e1fc5c957d967b9 Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 6 Dec 2023 09:42:58 +0000 Subject: [PATCH 01/32] Move to using managed identity for auth to CosmosDB. --- api_app/_version.py | 2 +- api_app/api/dependencies/database.py | 40 +++++++++++++++++----------- api_app/db/events.py | 3 ++- api_app/services/health_checker.py | 7 ++--- 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index 8bc9d14811..7913b9a0f4 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.16.9" +__version__ = "0.17.7" diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index 40c064b30f..c1ab40dc09 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -5,6 +5,7 @@ from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient from fastapi import Depends, FastAPI, HTTPException from fastapi import Request, status +from core.config import MANAGED_IDENTITY_CLIENT_ID from core import config, credentials from db.errors import UnableToAccessDatabase from db.repositories.base import BaseRepository @@ -12,28 +13,35 @@ async def connect_to_db() -> CosmosClient: - logging.debug(f"Connecting to {config.STATE_STORE_ENDPOINT}") - - try: - async with credentials.get_credential_async() as credential: - primary_master_key = await get_store_key(credential) - - if config.STATE_STORE_SSL_VERIFY: + logger.debug(f"Connecting to {config.STATE_STORE_ENDPOINT}") + async with credentials.get_credential_async() as credential: + if MANAGED_IDENTITY_CLIENT_ID: + logger.debug("Connecting with managed identity") cosmos_client = CosmosClient( - url=config.STATE_STORE_ENDPOINT, credential=primary_master_key + url=config.STATE_STORE_ENDPOINT, + credential=credential ) else: - # ignore TLS (setup is a pain) when using local Cosmos emulator. - cosmos_client = CosmosClient( - config.STATE_STORE_ENDPOINT, primary_master_key, connection_verify=False - ) - logging.debug("Connection established") - return cosmos_client - except Exception: - logging.exception("Connection to state store could not be established.") + logger.debug("Connecting with key") + primary_master_key = await get_store_key(credential) + + if config.STATE_STORE_SSL_VERIFY: + logger.debug("Connecting with SSL verification") + cosmos_client = CosmosClient( + url=config.STATE_STORE_ENDPOINT, credential=primary_master_key + ) + else: + logger.debug("Connecting without SSL verification") + # ignore TLS (setup is a pain) when using local Cosmos emulator. + cosmos_client = CosmosClient( + config.STATE_STORE_ENDPOINT, primary_master_key, connection_verify=False + ) + logger.debug("Connection established") + return cosmos_client async def get_store_key(credential) -> str: + logger.debug("Getting store key") if config.STATE_STORE_KEY: primary_master_key = config.STATE_STORE_KEY else: diff --git a/api_app/db/events.py b/api_app/db/events.py index f1fb407482..02ba9eaf7d 100644 --- a/api_app/db/events.py +++ b/api_app/db/events.py @@ -16,5 +16,6 @@ async def bootstrap_database(app) -> bool: await ResourceRepository.create(client) return True except Exception as e: - logging.debug(e) + logger.exception("Could not bootstrap database") + logger.debug(e) return False diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index 42a01e9ff0..576b6e7eb8 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -1,12 +1,11 @@ import logging from typing import Tuple from azure.core import exceptions -from azure.cosmos.aio import CosmosClient from azure.servicebus.aio import ServiceBusClient from azure.mgmt.compute.aio import ComputeManagementClient from azure.cosmos.exceptions import CosmosHttpResponseError from azure.servicebus.exceptions import ServiceBusConnectionError, ServiceBusAuthenticationError -from api.dependencies.database import get_store_key +from api.dependencies.database import connect_to_db from core import config from models.schemas.status import StatusEnum @@ -16,10 +15,8 @@ async def create_state_store_status(credential) -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" - debug = True if config.DEBUG == "true" else False try: - primary_master_key = await get_store_key(credential) - cosmos_client = CosmosClient(config.STATE_STORE_ENDPOINT, primary_master_key, connection_verify=debug) + cosmos_client = connect_to_db() async with cosmos_client: list_databases_response = cosmos_client.list_databases() [database async for database in list_databases_response] From 81ceb1ed595fd0bee0782a585577b7b5467e998d Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 6 Dec 2023 10:46:40 +0000 Subject: [PATCH 02/32] Add permissions to TRE DB to the API MSI --- core/terraform/api-identity.tf | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/core/terraform/api-identity.tf b/core/terraform/api-identity.tf index c68213f816..8209e37143 100644 --- a/core/terraform/api-identity.tf +++ b/core/terraform/api-identity.tf @@ -45,3 +45,16 @@ resource "azurerm_role_assignment" "cosmos_contributor" { principal_id = azurerm_user_assigned_identity.id.principal_id } +data "azurerm_cosmosdb_sql_role_definition" "cosmosdb_db_contributor" { + resource_group_name = azurerm_resource_group.core.name + account_name = azurerm_cosmosdb_account.tre_db_account.name + role_definition_id = "00000000-0000-0000-0000-000000000002" # Cosmos DB Built-in Data Contributor +} + +resource "azurerm_cosmosdb_sql_role_assignment" "tre_db_contributor" { + resource_group_name = azurerm_resource_group.core.name + account_name = azurerm_cosmosdb_account.tre_db_account.name + role_definition_id = data.azurerm_cosmosdb_sql_role_definition.cosmosdb_db_contributor.id + principal_id = azurerm_user_assigned_identity.id.principal_id + scope = azurerm_cosmosdb_account.tre_db_account.id +} From 4ff3dfe72759914b7170852b06e04e3b44e212a8 Mon Sep 17 00:00:00 2001 From: marrobi Date: Thu, 7 Dec 2023 17:06:44 +0000 Subject: [PATCH 03/32] Update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6de49e9d2..86a16d06c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ FEATURES: ENHANCEMENTS: * Switch from OpenCensus to OpenTelemetry for logging ([#3762](https://github.com/microsoft/AzureTRE/pull/3762)) +* Use mangaged identity for API connection to CosmosDB ([#345](https://github.com/microsoft/AzureTRE/issues/345)) BUG FIXES: From 45a9fb7d7a8e10daa3a89e0f54f8b60516d850bc Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 20 Dec 2023 11:48:27 +0000 Subject: [PATCH 04/32] Add missing await and fix tests. --- CHANGELOG.md | 2 +- api_app/services/health_checker.py | 2 +- .../tests_ma/test_services/test_health_checker.py | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 86a16d06c9..dfb80db153 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ FEATURES: ENHANCEMENTS: * Switch from OpenCensus to OpenTelemetry for logging ([#3762](https://github.com/microsoft/AzureTRE/pull/3762)) -* Use mangaged identity for API connection to CosmosDB ([#345](https://github.com/microsoft/AzureTRE/issues/345)) +* Use managed identity for API connection to CosmosDB ([#345](https://github.com/microsoft/AzureTRE/issues/345)) BUG FIXES: diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index 9798a92e06..057414626c 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -16,7 +16,7 @@ async def create_state_store_status(credential) -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" try: - cosmos_client = connect_to_db() + cosmos_client = await connect_to_db() async with cosmos_client: list_databases_response = cosmos_client.list_databases() [database async for database in list_databases_response] diff --git a/api_app/tests_ma/test_services/test_health_checker.py b/api_app/tests_ma/test_services/test_health_checker.py index 84ff3ae989..fd9221016b 100644 --- a/api_app/tests_ma/test_services/test_health_checker.py +++ b/api_app/tests_ma/test_services/test_health_checker.py @@ -12,8 +12,8 @@ @patch("core.credentials.get_credential_async") -@patch("services.health_checker.get_store_key") -@patch("services.health_checker.CosmosClient") +@patch("api.dependencies.database.get_store_key") +@patch("api.dependencies.database.CosmosClient") async def test_get_state_store_status_responding(_, get_store_key_mock, get_credential_async) -> None: get_store_key_mock.return_value = None status, message = await health_checker.create_state_store_status(get_credential_async) @@ -23,8 +23,8 @@ async def test_get_state_store_status_responding(_, get_store_key_mock, get_cred @patch("core.credentials.get_credential_async") -@patch("services.health_checker.get_store_key") -@patch("services.health_checker.CosmosClient") +@patch("api.dependencies.database.get_store_key") +@patch("api.dependencies.database.CosmosClient") async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: get_credential_async.return_value = AsyncMock() get_store_key_mock.return_value = None @@ -37,8 +37,8 @@ async def test_get_state_store_status_not_responding(cosmos_client_mock, get_sto @patch("core.credentials.get_credential_async") -@patch("services.health_checker.get_store_key") -@patch("services.health_checker.CosmosClient") +@patch("api.dependencies.database.get_store_key") +@patch("api.dependencies.database.CosmosClient") async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: get_credential_async.return_value = AsyncMock() get_store_key_mock.return_value = None From 3914d31c5189434166345a6c38aa1d7010f92c0e Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 20 Dec 2023 11:53:09 +0000 Subject: [PATCH 05/32] update core version --- core/version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/version.txt b/core/version.txt index d69d16e980..a2fecb4576 100644 --- a/core/version.txt +++ b/core/version.txt @@ -1 +1 @@ -__version__ = "0.9.1" +__version__ = "0.9.2" From fc2dd9a71d9ca7367b13f726e24dcecb2a408b86 Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 20 Dec 2023 11:55:23 +0000 Subject: [PATCH 06/32] Update core version --- core/version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/version.txt b/core/version.txt index a2fecb4576..c5981731c5 100644 --- a/core/version.txt +++ b/core/version.txt @@ -1 +1 @@ -__version__ = "0.9.2" +__version__ = "0.9.3" From 5e7a41757980561c6aed1dff48698dca167af98f Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 20 Dec 2023 17:31:53 +0000 Subject: [PATCH 07/32] remove DB create as this is done in terraform --- api_app/_version.py | 2 +- api_app/db/events.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index 7913b9a0f4..782f49be9b 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.17.7" +__version__ = "0.17.8" diff --git a/api_app/db/events.py b/api_app/db/events.py index bc25d57bf5..a86a9dc6f7 100644 --- a/api_app/db/events.py +++ b/api_app/db/events.py @@ -2,7 +2,6 @@ from api.dependencies.database import get_db_client from db.repositories.resources import ResourceRepository -from core import config from services.logging import logger @@ -10,7 +9,6 @@ async def bootstrap_database(app) -> bool: try: client: CosmosClient = await get_db_client(app) if client: - await client.create_database_if_not_exists(id=config.STATE_STORE_DATABASE) # Test access to database await ResourceRepository.create(client) return True From 97018a2fe99e4e97e770edb5e1d3ac5bf4a82e4e Mon Sep 17 00:00:00 2001 From: marrobi Date: Thu, 21 Dec 2023 12:00:36 +0000 Subject: [PATCH 08/32] split management plane and data plane operations. --- api_app/_version.py | 2 +- api_app/api/routes/health.py | 6 +-- api_app/db/events.py | 46 +++++++++++++++---- api_app/db/repositories/base.py | 2 +- api_app/main.py | 2 +- api_app/services/health_checker.py | 6 +-- api_app/tests_ma/test_db/test_events.py | 25 ++++++++++ .../test_services/test_health_checker.py | 4 +- 8 files changed, 73 insertions(+), 20 deletions(-) create mode 100644 api_app/tests_ma/test_db/test_events.py diff --git a/api_app/_version.py b/api_app/_version.py index 782f49be9b..a85820bb10 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.17.8" +__version__ = "0.17.12" diff --git a/api_app/api/routes/health.py b/api_app/api/routes/health.py index 301a6fd54d..e0bab1f33b 100644 --- a/api_app/api/routes/health.py +++ b/api_app/api/routes/health.py @@ -1,5 +1,5 @@ import asyncio -from fastapi import APIRouter +from fastapi import APIRouter, Request from core import credentials from models.schemas.status import HealthCheck, ServiceStatus, StatusEnum from resources import strings @@ -10,13 +10,13 @@ @router.get("/health", name=strings.API_GET_HEALTH_STATUS) -async def health_check() -> HealthCheck: +async def health_check(request: Request) -> HealthCheck: # The health endpoint checks the status of key components of the system. # Note that Resource Processor checks incur Azure management calls, so # calling this endpoint frequently may result in API throttling. async with credentials.get_credential_async() as credential: cosmos, sb, rp = await asyncio.gather( - create_state_store_status(credential), + create_state_store_status(request), create_service_bus_status(credential), create_resource_processor_status(credential) ) diff --git a/api_app/db/events.py b/api_app/db/events.py index a86a9dc6f7..f03c997462 100644 --- a/api_app/db/events.py +++ b/api_app/db/events.py @@ -1,17 +1,45 @@ -from azure.cosmos.aio import CosmosClient +from azure.mgmt.cosmosdb import CosmosDBManagementClient -from api.dependencies.database import get_db_client -from db.repositories.resources import ResourceRepository +from core.config import SUBSCRIPTION_ID, RESOURCE_GROUP_NAME, RESOURCE_LOCATION, COSMOSDB_ACCOUNT_NAME, STATE_STORE_DATABASE, STATE_STORE_RESOURCES_CONTAINER, STATE_STORE_RESOURCE_TEMPLATES_CONTAINER, STATE_STORE_RESOURCES_HISTORY_CONTAINER, STATE_STORE_OPERATIONS_CONTAINER, STATE_STORE_AIRLOCK_REQUESTS_CONTAINER +from core.credentials import get_credential from services.logging import logger -async def bootstrap_database(app) -> bool: +async def bootstrap_database() -> bool: try: - client: CosmosClient = await get_db_client(app) - if client: - # Test access to database - await ResourceRepository.create(client) - return True + credential = get_credential() + db_mgmt_client = CosmosDBManagementClient(credential=credential, subscription_id=SUBSCRIPTION_ID) + + repository_containers = [ + STATE_STORE_RESOURCES_CONTAINER, + STATE_STORE_RESOURCE_TEMPLATES_CONTAINER, + STATE_STORE_RESOURCES_HISTORY_CONTAINER, + STATE_STORE_OPERATIONS_CONTAINER, + STATE_STORE_AIRLOCK_REQUESTS_CONTAINER + ] + + for container in repository_containers: + # create container if it doesn't exist + db_mgmt_client.sql_resources.begin_create_update_sql_container( + resource_group_name=RESOURCE_GROUP_NAME, + account_name=COSMOSDB_ACCOUNT_NAME, + database_name=STATE_STORE_DATABASE, + container_name=container, + create_update_sql_container_parameters={ + "location": RESOURCE_LOCATION, + "resource": { + "id": container, + "partition_key": { + "paths": [ + "/id" + ], + "kind": "Hash" + } + } + } + ) + return True + except Exception as e: logger.exception("Could not bootstrap database") logger.debug(e) diff --git a/api_app/db/repositories/base.py b/api_app/db/repositories/base.py index 35395fa064..fcecb068eb 100644 --- a/api_app/db/repositories/base.py +++ b/api_app/db/repositories/base.py @@ -24,7 +24,7 @@ def container(self) -> ContainerProxy: async def _get_container(cls, container_name, partition_key_obj) -> ContainerProxy: try: database = cls._client.get_database_client(config.STATE_STORE_DATABASE) - container = await database.create_container_if_not_exists(id=container_name, partition_key=partition_key_obj) + container = database.get_container_client(container=container_name) return container except Exception: raise UnableToAccessDatabase diff --git a/api_app/main.py b/api_app/main.py index 703b8f4d77..b7e7fa8aaf 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -26,7 +26,7 @@ async def lifespan(app: FastAPI): app.state.cosmos_client = None - while not await bootstrap_database(app): + while not await bootstrap_database(): await asyncio.sleep(5) logger.warning("Database connection could not be established") diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index 057414626c..aa8a5013a2 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -4,7 +4,7 @@ from azure.mgmt.compute.aio import ComputeManagementClient from azure.cosmos.exceptions import CosmosHttpResponseError from azure.servicebus.exceptions import ServiceBusConnectionError, ServiceBusAuthenticationError -from api.dependencies.database import connect_to_db +from api.dependencies.database import get_db_client_from_request from core import config from models.schemas.status import StatusEnum @@ -12,11 +12,11 @@ from services.logging import logger -async def create_state_store_status(credential) -> Tuple[StatusEnum, str]: +async def create_state_store_status(request) -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" try: - cosmos_client = await connect_to_db() + cosmos_client = await get_db_client_from_request(request) async with cosmos_client: list_databases_response = cosmos_client.list_databases() [database async for database in list_databases_response] diff --git a/api_app/tests_ma/test_db/test_events.py b/api_app/tests_ma/test_db/test_events.py new file mode 100644 index 0000000000..a464b29442 --- /dev/null +++ b/api_app/tests_ma/test_db/test_events.py @@ -0,0 +1,25 @@ +from unittest.mock import AsyncMock, MagicMock, patch +from azure.core.exceptions import AzureError +from api_app.db import events + + +@patch("api_app.db.events.get_credential_async") +@patch("api_app.db.events.CosmosDBManagementClient") +async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_mock): + get_credential_async_mock.return_value = AsyncMock() + cosmos_db_mgmt_client_mock.return_value = MagicMock() + + result = await events.bootstrap_database() + + assert result is True + + +@patch("api_app.db.events.get_credential_async") +@patch("api_app.db.events.CosmosDBManagementClient") +async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_mock): + get_credential_async_mock.return_value = AsyncMock() + cosmos_db_mgmt_client_mock.side_effect = AzureError() + + result = await events.bootstrap_database() + + assert result is False diff --git a/api_app/tests_ma/test_services/test_health_checker.py b/api_app/tests_ma/test_services/test_health_checker.py index fd9221016b..76587ef134 100644 --- a/api_app/tests_ma/test_services/test_health_checker.py +++ b/api_app/tests_ma/test_services/test_health_checker.py @@ -24,7 +24,7 @@ async def test_get_state_store_status_responding(_, get_store_key_mock, get_cred @patch("core.credentials.get_credential_async") @patch("api.dependencies.database.get_store_key") -@patch("api.dependencies.database.CosmosClient") +@patch("api.dependencies.database.get_db_client") async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: get_credential_async.return_value = AsyncMock() get_store_key_mock.return_value = None @@ -38,7 +38,7 @@ async def test_get_state_store_status_not_responding(cosmos_client_mock, get_sto @patch("core.credentials.get_credential_async") @patch("api.dependencies.database.get_store_key") -@patch("api.dependencies.database.CosmosClient") +@patch("api.dependencies.database.get_db_client") async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: get_credential_async.return_value = AsyncMock() get_store_key_mock.return_value = None From 11f0f16fd5b3dce31d7e7eaf3ec92df4db8fe036 Mon Sep 17 00:00:00 2001 From: marrobi Date: Thu, 21 Dec 2023 12:07:59 +0000 Subject: [PATCH 09/32] fix test --- api_app/tests_ma/test_db/test_events.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/api_app/tests_ma/test_db/test_events.py b/api_app/tests_ma/test_db/test_events.py index a464b29442..bcef3e3c2a 100644 --- a/api_app/tests_ma/test_db/test_events.py +++ b/api_app/tests_ma/test_db/test_events.py @@ -1,9 +1,11 @@ from unittest.mock import AsyncMock, MagicMock, patch from azure.core.exceptions import AzureError +import pytest from api_app.db import events +pytestmark = pytest.mark.asyncio -@patch("api_app.db.events.get_credential_async") +@patch("api_app.db.events.get_credential") @patch("api_app.db.events.CosmosDBManagementClient") async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_mock): get_credential_async_mock.return_value = AsyncMock() @@ -14,11 +16,11 @@ async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_creden assert result is True -@patch("api_app.db.events.get_credential_async") +@patch("api_app.db.events.get_credential") @patch("api_app.db.events.CosmosDBManagementClient") async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_mock): get_credential_async_mock.return_value = AsyncMock() - cosmos_db_mgmt_client_mock.side_effect = AzureError() + cosmos_db_mgmt_client_mock.side_effect = AzureError("some error") result = await events.bootstrap_database() From d0752e4290d25d5058e19f851c167c9650b13289 Mon Sep 17 00:00:00 2001 From: marrobi Date: Thu, 21 Dec 2023 12:13:15 +0000 Subject: [PATCH 10/32] fix test --- api_app/tests_ma/test_db/test_events.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/api_app/tests_ma/test_db/test_events.py b/api_app/tests_ma/test_db/test_events.py index bcef3e3c2a..7b93a16a82 100644 --- a/api_app/tests_ma/test_db/test_events.py +++ b/api_app/tests_ma/test_db/test_events.py @@ -1,12 +1,13 @@ from unittest.mock import AsyncMock, MagicMock, patch from azure.core.exceptions import AzureError import pytest -from api_app.db import events +from db import events pytestmark = pytest.mark.asyncio -@patch("api_app.db.events.get_credential") -@patch("api_app.db.events.CosmosDBManagementClient") + +@patch("db.events.get_credential") +@patch("db.events.CosmosDBManagementClient") async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_mock): get_credential_async_mock.return_value = AsyncMock() cosmos_db_mgmt_client_mock.return_value = MagicMock() @@ -16,8 +17,8 @@ async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_creden assert result is True -@patch("api_app.db.events.get_credential") -@patch("api_app.db.events.CosmosDBManagementClient") +@patch("db.events.get_credential") +@patch("db.events.CosmosDBManagementClient") async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_mock): get_credential_async_mock.return_value = AsyncMock() cosmos_db_mgmt_client_mock.side_effect = AzureError("some error") From c345aa26fe5aed03b0fdbf5fd11b73a040c8dc2d Mon Sep 17 00:00:00 2001 From: marrobi Date: Thu, 21 Dec 2023 12:36:42 +0000 Subject: [PATCH 11/32] Refactor container creation. --- api_app/db/events.py | 56 ++++++++++---------- api_app/db/repositories/base.py | 7 ++- api_app/db/repositories/resources_history.py | 2 +- 3 files changed, 32 insertions(+), 33 deletions(-) diff --git a/api_app/db/events.py b/api_app/db/events.py index f03c997462..c6b79a730a 100644 --- a/api_app/db/events.py +++ b/api_app/db/events.py @@ -10,37 +10,37 @@ async def bootstrap_database() -> bool: credential = get_credential() db_mgmt_client = CosmosDBManagementClient(credential=credential, subscription_id=SUBSCRIPTION_ID) - repository_containers = [ - STATE_STORE_RESOURCES_CONTAINER, - STATE_STORE_RESOURCE_TEMPLATES_CONTAINER, - STATE_STORE_RESOURCES_HISTORY_CONTAINER, - STATE_STORE_OPERATIONS_CONTAINER, - STATE_STORE_AIRLOCK_REQUESTS_CONTAINER - ] - - for container in repository_containers: - # create container if it doesn't exist - db_mgmt_client.sql_resources.begin_create_update_sql_container( - resource_group_name=RESOURCE_GROUP_NAME, - account_name=COSMOSDB_ACCOUNT_NAME, - database_name=STATE_STORE_DATABASE, - container_name=container, - create_update_sql_container_parameters={ - "location": RESOURCE_LOCATION, - "resource": { - "id": container, - "partition_key": { - "paths": [ - "/id" - ], - "kind": "Hash" - } - } - } - ) + create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCES_CONTAINER, "/id") + create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCE_TEMPLATES_CONTAINER, "/id") + create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCES_HISTORY_CONTAINER, "/resourceId") + create_container_if_not_exists(db_mgmt_client, STATE_STORE_OPERATIONS_CONTAINER, "/id") + create_container_if_not_exists(db_mgmt_client, STATE_STORE_AIRLOCK_REQUESTS_CONTAINER, "/id") + return True except Exception as e: logger.exception("Could not bootstrap database") logger.debug(e) return False + + +async def create_container_if_not_exists(db_mgmt_client, container, partition_key): + + db_mgmt_client.sql_resources.begin_create_update_sql_container( + resource_group_name=RESOURCE_GROUP_NAME, + account_name=COSMOSDB_ACCOUNT_NAME, + database_name=STATE_STORE_DATABASE, + container_name=container, + create_update_sql_container_parameters={ + "location": RESOURCE_LOCATION, + "resource": { + "id": container, + "partition_key": { + "paths": [ + partition_key + ], + "kind": "Hash" + } + } + } + ) diff --git a/api_app/db/repositories/base.py b/api_app/db/repositories/base.py index fcecb068eb..0ee4b1b678 100644 --- a/api_app/db/repositories/base.py +++ b/api_app/db/repositories/base.py @@ -10,10 +10,9 @@ class BaseRepository: @classmethod - async def create(cls, client: CosmosClient, container_name: Optional[str] = None, partition_key: str = "/id"): - partition_key_obj = PartitionKey(path=partition_key) + async def create(cls, client: CosmosClient, container_name: Optional[str] = None): cls._client: CosmosClient = client - cls._container: ContainerProxy = await cls._get_container(container_name, partition_key_obj) + cls._container: ContainerProxy = await cls._get_container(container_name) return cls @property @@ -21,7 +20,7 @@ def container(self) -> ContainerProxy: return self._container @classmethod - async def _get_container(cls, container_name, partition_key_obj) -> ContainerProxy: + async def _get_container(cls, container_name) -> ContainerProxy: try: database = cls._client.get_database_client(config.STATE_STORE_DATABASE) container = database.get_container_client(container=container_name) diff --git a/api_app/db/repositories/resources_history.py b/api_app/db/repositories/resources_history.py index 2a6524d62d..6005619b39 100644 --- a/api_app/db/repositories/resources_history.py +++ b/api_app/db/repositories/resources_history.py @@ -14,7 +14,7 @@ class ResourceHistoryRepository(BaseRepository): @classmethod async def create(cls, client: CosmosClient): cls = ResourceHistoryRepository() - await super().create(client, config.STATE_STORE_RESOURCES_HISTORY_CONTAINER, "/resourceId") + await super().create(client, config.STATE_STORE_RESOURCES_HISTORY_CONTAINER) return cls @staticmethod From 0f6f592fefe0269b33cdc3a57cfb324d2a63e26f Mon Sep 17 00:00:00 2001 From: marrobi Date: Thu, 21 Dec 2023 12:37:08 +0000 Subject: [PATCH 12/32] remove import --- api_app/db/repositories/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api_app/db/repositories/base.py b/api_app/db/repositories/base.py index 0ee4b1b678..15ea9bf0fc 100644 --- a/api_app/db/repositories/base.py +++ b/api_app/db/repositories/base.py @@ -1,6 +1,5 @@ from typing import Optional from azure.cosmos.aio import CosmosClient, ContainerProxy -from azure.cosmos import PartitionKey from azure.core import MatchConditions from pydantic import BaseModel From f87f2bc292ed0f7397a1fd0664b0c9cfcb64725f Mon Sep 17 00:00:00 2001 From: marrobi Date: Thu, 21 Dec 2023 16:43:22 +0000 Subject: [PATCH 13/32] Added workaround for transport being closed. --- api_app/_version.py | 2 +- api_app/api/dependencies/database.py | 51 +++++++++++++++++----------- api_app/db/events.py | 13 ++++--- 3 files changed, 41 insertions(+), 25 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index a85820bb10..ddf3e33504 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.17.12" +__version__ = "0.17.13" diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index 790627f15a..318c2b8996 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -2,10 +2,9 @@ from azure.cosmos.aio import CosmosClient from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient -from fastapi import Depends, FastAPI, HTTPException -from fastapi import Request, status -from core.config import MANAGED_IDENTITY_CLIENT_ID -from core import config, credentials +from fastapi import Depends, FastAPI, HTTPException, Request, status +from core.config import MANAGED_IDENTITY_CLIENT_ID, STATE_STORE_ENDPOINT, STATE_STORE_KEY, STATE_STORE_SSL_VERIFY, SUBSCRIPTION_ID, RESOURCE_MANAGER_ENDPOINT, CREDENTIAL_SCOPES, RESOURCE_GROUP_NAME, COSMOSDB_ACCOUNT_NAME +from core.credentials import get_credential_async from db.errors import UnableToAccessDatabase from db.repositories.base import BaseRepository from resources import strings @@ -13,28 +12,29 @@ async def connect_to_db() -> CosmosClient: - logger.debug(f"Connecting to {config.STATE_STORE_ENDPOINT}") - async with credentials.get_credential_async() as credential: + logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}") + + async with get_credential_async() as credential: if MANAGED_IDENTITY_CLIENT_ID: logger.debug("Connecting with managed identity") cosmos_client = CosmosClient( - url=config.STATE_STORE_ENDPOINT, + url=STATE_STORE_ENDPOINT, credential=credential ) else: logger.debug("Connecting with key") primary_master_key = await get_store_key(credential) - if config.STATE_STORE_SSL_VERIFY: + if STATE_STORE_SSL_VERIFY: logger.debug("Connecting with SSL verification") cosmos_client = CosmosClient( - url=config.STATE_STORE_ENDPOINT, credential=primary_master_key + url=STATE_STORE_ENDPOINT, credential=primary_master_key ) else: logger.debug("Connecting without SSL verification") # ignore TLS (setup is a pain) when using local Cosmos emulator. cosmos_client = CosmosClient( - config.STATE_STORE_ENDPOINT, primary_master_key, connection_verify=False + STATE_STORE_ENDPOINT, primary_master_key, connection_verify=False ) logger.debug("Connection established") return cosmos_client @@ -42,18 +42,18 @@ async def connect_to_db() -> CosmosClient: async def get_store_key(credential) -> str: logger.debug("Getting store key") - if config.STATE_STORE_KEY: - primary_master_key = config.STATE_STORE_KEY + if STATE_STORE_KEY: + primary_master_key = STATE_STORE_KEY else: async with CosmosDBManagementClient( credential, - subscription_id=config.SUBSCRIPTION_ID, - base_url=config.RESOURCE_MANAGER_ENDPOINT, - credential_scopes=config.CREDENTIAL_SCOPES + subscription_id=SUBSCRIPTION_ID, + base_url=RESOURCE_MANAGER_ENDPOINT, + credential_scopes=CREDENTIAL_SCOPES ) as cosmosdb_mng_client: database_keys = await cosmosdb_mng_client.database_accounts.list_keys( - resource_group_name=config.RESOURCE_GROUP_NAME, - account_name=config.COSMOSDB_ACCOUNT_NAME, + resource_group_name=RESOURCE_GROUP_NAME, + account_name=COSMOSDB_ACCOUNT_NAME, ) primary_master_key = database_keys.primary_master_key @@ -61,8 +61,21 @@ async def get_store_key(credential) -> str: async def get_db_client(app: FastAPI) -> CosmosClient: - if not hasattr(app.state, 'cosmos_client') or not app.state.cosmos_client: - app.state.cosmos_client = await connect_to_db() + logger.debug("Getting cosmos client") + cosmos_client = None + if hasattr(app.state, 'cosmos_client') and app.state.cosmos_client: + logger.debug("Cosmos client found in state") + cosmos_client = app.state.cosmos_client + # TODO: if session is closed recreate - need to investigate why this is happening + # https://github.com/Azure/azure-sdk-for-python/issues/32309 + if hasattr(cosmos_client.client_connection, "session") and not cosmos_client.client_connection.session: + logger.debug("Cosmos client session is None") + cosmos_client = await connect_to_db() + else: + logger.debug("No cosmos client found, creating one") + cosmos_client = await connect_to_db() + + app.state.cosmos_client = cosmos_client return app.state.cosmos_client diff --git a/api_app/db/events.py b/api_app/db/events.py index c6b79a730a..462af10ed5 100644 --- a/api_app/db/events.py +++ b/api_app/db/events.py @@ -1,3 +1,4 @@ +import asyncio from azure.mgmt.cosmosdb import CosmosDBManagementClient from core.config import SUBSCRIPTION_ID, RESOURCE_GROUP_NAME, RESOURCE_LOCATION, COSMOSDB_ACCOUNT_NAME, STATE_STORE_DATABASE, STATE_STORE_RESOURCES_CONTAINER, STATE_STORE_RESOURCE_TEMPLATES_CONTAINER, STATE_STORE_RESOURCES_HISTORY_CONTAINER, STATE_STORE_OPERATIONS_CONTAINER, STATE_STORE_AIRLOCK_REQUESTS_CONTAINER @@ -10,11 +11,13 @@ async def bootstrap_database() -> bool: credential = get_credential() db_mgmt_client = CosmosDBManagementClient(credential=credential, subscription_id=SUBSCRIPTION_ID) - create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCES_CONTAINER, "/id") - create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCE_TEMPLATES_CONTAINER, "/id") - create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCES_HISTORY_CONTAINER, "/resourceId") - create_container_if_not_exists(db_mgmt_client, STATE_STORE_OPERATIONS_CONTAINER, "/id") - create_container_if_not_exists(db_mgmt_client, STATE_STORE_AIRLOCK_REQUESTS_CONTAINER, "/id") + await asyncio.gather( + create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCES_CONTAINER, "/id"), + create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCE_TEMPLATES_CONTAINER, "/id"), + create_container_if_not_exists(db_mgmt_client, STATE_STORE_RESOURCES_HISTORY_CONTAINER, "/resourceId"), + create_container_if_not_exists(db_mgmt_client, STATE_STORE_OPERATIONS_CONTAINER, "/id"), + create_container_if_not_exists(db_mgmt_client, STATE_STORE_AIRLOCK_REQUESTS_CONTAINER, "/id") + ) return True From 69c8568e46dbf2253502aacbb65824a960cb126d Mon Sep 17 00:00:00 2001 From: marrobi Date: Fri, 22 Dec 2023 11:22:19 +0000 Subject: [PATCH 14/32] Move database connection to singleton --- api_app/api/dependencies/airlock.py | 4 +- api_app/api/dependencies/database.py | 165 +++++++++--------- api_app/api/dependencies/shared_services.py | 6 +- .../workspace_service_templates.py | 4 +- api_app/api/dependencies/workspaces.py | 14 +- api_app/api/routes/airlock.py | 46 ++--- api_app/api/routes/api.py | 6 +- api_app/api/routes/costs.py | 12 +- api_app/api/routes/migrations.py | 16 +- api_app/api/routes/operations.py | 4 +- .../api/routes/shared_service_templates.py | 8 +- api_app/api/routes/shared_services.py | 18 +- api_app/api/routes/user_resource_templates.py | 8 +- .../api/routes/workspace_service_templates.py | 8 +- api_app/api/routes/workspace_templates.py | 8 +- api_app/api/routes/workspaces.py | 80 ++++----- api_app/main.py | 4 +- .../airlock_request_status_update.py | 8 +- .../service_bus/deployment_status_updater.py | 8 +- api_app/services/aad_authentication.py | 5 +- api_app/services/health_checker.py | 4 +- api_app/tests_ma/conftest.py | 4 +- api_app/tests_ma/test_api/conftest.py | 4 +- .../test_api/test_routes/test_workspaces.py | 4 +- .../test_airlock_request_status_update.py | 36 ++-- .../test_deployment_status_update.py | 60 +++---- .../test_services/test_health_checker.py | 12 +- 27 files changed, 278 insertions(+), 278 deletions(-) diff --git a/api_app/api/dependencies/airlock.py b/api_app/api/dependencies/airlock.py index 4a15aa2741..efca378996 100644 --- a/api_app/api/dependencies/airlock.py +++ b/api_app/api/dependencies/airlock.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api.dependencies.database import get_repository +from api.dependencies.database import Database from db.repositories.airlock_requests import AirlockRequestRepository from models.domain.airlock_request import AirlockRequest from db.errors import EntityDoesNotExist, UnableToAccessDatabase @@ -17,5 +17,5 @@ async def get_airlock_request_by_id(airlock_request_id: UUID4, airlock_request_r raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=strings.STATE_STORE_ENDPOINT_NOT_RESPONDING) -async def get_airlock_request_by_id_from_path(airlock_request_id: UUID4 = Path(...), airlock_request_repo=Depends(get_repository(AirlockRequestRepository))) -> AirlockRequest: +async def get_airlock_request_by_id_from_path(airlock_request_id: UUID4 = Path(...), airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository))) -> AirlockRequest: return await get_airlock_request_by_id(airlock_request_id, airlock_request_repo) diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index 318c2b8996..373be00df1 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -11,91 +11,90 @@ from services.logging import logger -async def connect_to_db() -> CosmosClient: - logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}") - - async with get_credential_async() as credential: - if MANAGED_IDENTITY_CLIENT_ID: - logger.debug("Connecting with managed identity") - cosmos_client = CosmosClient( - url=STATE_STORE_ENDPOINT, - credential=credential - ) - else: - logger.debug("Connecting with key") - primary_master_key = await get_store_key(credential) +class Singleton(type): + _instances = {} + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class Database(metaclass=Singleton): + cosmos_client = None + + def __init__(self): + pass + + @classmethod + async def _connect_to_db(self) -> CosmosClient: + logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}") - if STATE_STORE_SSL_VERIFY: - logger.debug("Connecting with SSL verification") + async with get_credential_async() as credential: + if MANAGED_IDENTITY_CLIENT_ID: + logger.debug("Connecting with managed identity") cosmos_client = CosmosClient( - url=STATE_STORE_ENDPOINT, credential=primary_master_key + url=STATE_STORE_ENDPOINT, + credential=credential ) else: - logger.debug("Connecting without SSL verification") - # ignore TLS (setup is a pain) when using local Cosmos emulator. - cosmos_client = CosmosClient( - STATE_STORE_ENDPOINT, primary_master_key, connection_verify=False + logger.debug("Connecting with key") + primary_master_key = await self._get_store_key(credential) + + if STATE_STORE_SSL_VERIFY: + logger.debug("Connecting with SSL verification") + cosmos_client = CosmosClient( + url=STATE_STORE_ENDPOINT, credential=primary_master_key + ) + else: + logger.debug("Connecting without SSL verification") + # ignore TLS (setup is a pain) when using local Cosmos emulator. + cosmos_client = CosmosClient( + STATE_STORE_ENDPOINT, primary_master_key, connection_verify=False + ) + logger.debug("Connection established") + return cosmos_client + + @classmethod + async def _get_store_key(self, credential) -> str: + logger.debug("Getting store key") + if STATE_STORE_KEY: + primary_master_key = STATE_STORE_KEY + else: + async with CosmosDBManagementClient( + credential, + subscription_id=SUBSCRIPTION_ID, + base_url=RESOURCE_MANAGER_ENDPOINT, + credential_scopes=CREDENTIAL_SCOPES + ) as cosmosdb_mng_client: + database_keys = await cosmosdb_mng_client.database_accounts.list_keys( + resource_group_name=RESOURCE_GROUP_NAME, + account_name=COSMOSDB_ACCOUNT_NAME, ) - logger.debug("Connection established") - return cosmos_client - - -async def get_store_key(credential) -> str: - logger.debug("Getting store key") - if STATE_STORE_KEY: - primary_master_key = STATE_STORE_KEY - else: - async with CosmosDBManagementClient( - credential, - subscription_id=SUBSCRIPTION_ID, - base_url=RESOURCE_MANAGER_ENDPOINT, - credential_scopes=CREDENTIAL_SCOPES - ) as cosmosdb_mng_client: - database_keys = await cosmosdb_mng_client.database_accounts.list_keys( - resource_group_name=RESOURCE_GROUP_NAME, - account_name=COSMOSDB_ACCOUNT_NAME, - ) - primary_master_key = database_keys.primary_master_key - - return primary_master_key - - -async def get_db_client(app: FastAPI) -> CosmosClient: - logger.debug("Getting cosmos client") - cosmos_client = None - if hasattr(app.state, 'cosmos_client') and app.state.cosmos_client: - logger.debug("Cosmos client found in state") - cosmos_client = app.state.cosmos_client - # TODO: if session is closed recreate - need to investigate why this is happening - # https://github.com/Azure/azure-sdk-for-python/issues/32309 - if hasattr(cosmos_client.client_connection, "session") and not cosmos_client.client_connection.session: - logger.debug("Cosmos client session is None") - cosmos_client = await connect_to_db() - else: - logger.debug("No cosmos client found, creating one") - cosmos_client = await connect_to_db() - - app.state.cosmos_client = cosmos_client - return app.state.cosmos_client - - -async def get_db_client_from_request(request: Request) -> CosmosClient: - return await get_db_client(request.app) - - -def get_repository( - repo_type: Type[BaseRepository], -) -> Callable[[CosmosClient], BaseRepository]: - async def _get_repo( - client: CosmosClient = Depends(get_db_client_from_request), - ) -> BaseRepository: - try: - return await repo_type.create(client) - except UnableToAccessDatabase: - logger.exception(strings.STATE_STORE_ENDPOINT_NOT_RESPONDING) - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=strings.STATE_STORE_ENDPOINT_NOT_RESPONDING, - ) - - return _get_repo + primary_master_key = database_keys.primary_master_key + + return primary_master_key + + @classmethod + async def get_db_client(self) -> CosmosClient: + logger.debug("Getting cosmos client") + if not Database.cosmos_client: + Database.cosmos_client = await self._connect_to_db() + return self.cosmos_client + + + @classmethod + def get_repository(self, + repo_type: Type[BaseRepository], + ) -> Callable[[CosmosClient], BaseRepository]: + + async def _get_repo() -> BaseRepository: + try: + return await repo_type.create(self.cosmos_client) + except UnableToAccessDatabase: + logger.exception(strings.STATE_STORE_ENDPOINT_NOT_RESPONDING) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=strings.STATE_STORE_ENDPOINT_NOT_RESPONDING, + ) + + return _get_repo diff --git a/api_app/api/dependencies/shared_services.py b/api_app/api/dependencies/shared_services.py index 970f120776..388ec8a3e5 100644 --- a/api_app/api/dependencies/shared_services.py +++ b/api_app/api/dependencies/shared_services.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api.dependencies.database import get_repository +from api.dependencies.database import Database from db.errors import EntityDoesNotExist from resources import strings from models.domain.shared_service import SharedService @@ -17,11 +17,11 @@ async def get_shared_service_by_id(shared_service_id: UUID4, shared_services_rep raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.SHARED_SERVICE_DOES_NOT_EXIST) -async def get_shared_service_by_id_from_path(shared_service_id: UUID4 = Path(...), shared_service_repo=Depends(get_repository(SharedServiceRepository))) -> SharedService: +async def get_shared_service_by_id_from_path(shared_service_id: UUID4 = Path(...), shared_service_repo=Depends(Database().get_repository(SharedServiceRepository))) -> SharedService: return await get_shared_service_by_id(shared_service_id, shared_service_repo) -async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(get_repository(OperationRepository))) -> Operation: +async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(Database().get_repository(OperationRepository))) -> Operation: try: return await operations_repo.get_operation_by_id(operation_id=operation_id) except EntityDoesNotExist: diff --git a/api_app/api/dependencies/workspace_service_templates.py b/api_app/api/dependencies/workspace_service_templates.py index 17d049d153..2a6e908722 100644 --- a/api_app/api/dependencies/workspace_service_templates.py +++ b/api_app/api/dependencies/workspace_service_templates.py @@ -1,6 +1,6 @@ from fastapi import Depends, HTTPException, Path, status -from api.dependencies.database import get_repository +from api.dependencies.database import Database from db.errors import EntityDoesNotExist from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType @@ -8,7 +8,7 @@ from resources import strings -async def get_workspace_service_template_by_name_from_path(service_template_name: str = Path(...), template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplate: +async def get_workspace_service_template_by_name_from_path(service_template_name: str = Path(...), template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplate: try: return await template_repo.get_current_template(service_template_name, ResourceType.WorkspaceService) except EntityDoesNotExist: diff --git a/api_app/api/dependencies/workspaces.py b/api_app/api/dependencies/workspaces.py index 40f566845f..90f98c63b5 100644 --- a/api_app/api/dependencies/workspaces.py +++ b/api_app/api/dependencies/workspaces.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api.dependencies.database import get_repository +from api.dependencies.database import Database from db.errors import EntityDoesNotExist, ResourceIsNotDeployed from db.repositories.operations import OperationRepository from db.repositories.user_resources import UserResourceRepository @@ -22,11 +22,11 @@ async def get_workspace_by_id(workspace_id: UUID4, workspaces_repo) -> Workspace raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.WORKSPACE_DOES_NOT_EXIST) -async def get_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(get_repository(WorkspaceRepository))) -> Workspace: +async def get_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(Database().get_repository(WorkspaceRepository))) -> Workspace: return await get_workspace_by_id(workspace_id, workspaces_repo) -async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(get_repository(WorkspaceRepository)), operations_repo=Depends(get_repository(OperationRepository))) -> Workspace: +async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(Database().get_repository(WorkspaceRepository)), operations_repo=Depends(Database().get_repository(OperationRepository))) -> Workspace: try: return await workspaces_repo.get_deployed_workspace_by_id(workspace_id, operations_repo) except EntityDoesNotExist: @@ -35,14 +35,14 @@ async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...) raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=strings.WORKSPACE_IS_NOT_DEPLOYED) -async def get_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository))) -> WorkspaceService: +async def get_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(Database().get_repository(WorkspaceServiceRepository))) -> WorkspaceService: try: return await workspace_services_repo.get_workspace_service_by_id(workspace_id, service_id) except EntityDoesNotExist: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.WORKSPACE_SERVICE_DOES_NOT_EXIST) -async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository)), operations_repo=Depends(get_repository(OperationRepository))) -> WorkspaceService: +async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), operations_repo=Depends(Database().get_repository(OperationRepository))) -> WorkspaceService: try: return await workspace_services_repo.get_deployed_workspace_service_by_id(workspace_id, service_id, operations_repo) except EntityDoesNotExist: @@ -51,14 +51,14 @@ async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = P raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=strings.WORKSPACE_SERVICE_IS_NOT_DEPLOYED) -async def get_user_resource_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), resource_id: UUID4 = Path(...), user_resource_repo=Depends(get_repository(UserResourceRepository))) -> UserResource: +async def get_user_resource_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), resource_id: UUID4 = Path(...), user_resource_repo=Depends(Database().get_repository(UserResourceRepository))) -> UserResource: try: return await user_resource_repo.get_user_resource_by_id(workspace_id, service_id, resource_id) except EntityDoesNotExist: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.USER_RESOURCE_DOES_NOT_EXIST) -async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(get_repository(OperationRepository))) -> Operation: +async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(Database().get_repository(OperationRepository))) -> Operation: try: return await operations_repo.get_operation_by_id(operation_id=operation_id) except EntityDoesNotExist: diff --git a/api_app/api/routes/airlock.py b/api_app/api/routes/airlock.py index 9fa9790e97..900b418867 100644 --- a/api_app/api/routes/airlock.py +++ b/api_app/api/routes/airlock.py @@ -11,7 +11,7 @@ from db.repositories.airlock_requests import AirlockRequestRepository from db.errors import EntityDoesNotExist, UserNotAuthorizedToUseTemplate -from api.dependencies.database import get_repository +from api.dependencies.database import Database from api.dependencies.workspaces import get_workspace_by_id_from_path, get_deployed_workspace_by_id_from_path from api.dependencies.airlock import get_airlock_request_by_id_from_path from models.domain.airlock_request import AirlockRequestStatus, AirlockRequestType @@ -36,7 +36,7 @@ response_model=AirlockRequestWithAllowedUserActions, name=strings.API_CREATE_AIRLOCK_REQUEST, dependencies=[Depends(get_current_workspace_owner_or_researcher_user), Depends(get_workspace_by_id_from_path)]) async def create_draft_request(airlock_request_input: AirlockRequestInCreate, user=Depends(get_current_workspace_owner_or_researcher_user), - airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), + airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), workspace=Depends(get_deployed_workspace_by_id_from_path)) -> AirlockRequestWithAllowedUserActions: if workspace.properties.get("enable_airlock") is False: raise HTTPException(status_code=status_code.HTTP_405_METHOD_NOT_ALLOWED, detail=strings.AIRLOCK_NOT_ENABLED_IN_WORKSPACE) @@ -57,7 +57,7 @@ async def create_draft_request(airlock_request_input: AirlockRequestInCreate, us dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) async def get_all_airlock_requests_by_workspace( - airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), + airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), workspace=Depends(get_deployed_workspace_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), creator_user_id: Optional[str] = None, type: Optional[AirlockRequestType] = None, status: Optional[AirlockRequestStatus] = None, @@ -77,7 +77,7 @@ async def get_all_airlock_requests_by_workspace( response_model=AirlockRequestWithAllowedUserActions, name=strings.API_GET_AIRLOCK_REQUEST, dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) async def retrieve_airlock_request_by_id(airlock_request=Depends(get_airlock_request_by_id_from_path), - airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), + airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)) -> AirlockRequestWithAllowedUserActions: allowed_actions = get_allowed_actions(airlock_request, user, airlock_request_repo) return AirlockRequestWithAllowedUserActions(airlockRequest=airlock_request, allowedUserActions=allowed_actions) @@ -88,7 +88,7 @@ async def retrieve_airlock_request_by_id(airlock_request=Depends(get_airlock_req dependencies=[Depends(get_current_workspace_owner_or_researcher_user), Depends(get_workspace_by_id_from_path)]) async def create_submit_request(airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user), - airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), + airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), workspace=Depends(get_workspace_by_id_from_path)) -> AirlockRequestWithAllowedUserActions: updated_request = await update_and_publish_event_airlock_request(airlock_request, airlock_request_repo, user, workspace, new_status=AirlockRequestStatus.Submitted) @@ -102,12 +102,12 @@ async def create_submit_request(airlock_request=Depends(get_airlock_request_by_i async def create_cancel_request(airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user), workspace=Depends(get_workspace_by_id_from_path), - airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), - user_resource_repo=Depends(get_repository(UserResourceRepository)), - workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), - resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), - operation_repo=Depends(get_repository(OperationRepository)), - resource_template_repo=Depends(get_repository(ResourceTemplateRepository)),) -> AirlockRequestWithAllowedUserActions: + airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), + user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), + workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), + resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), + operation_repo=Depends(Database().get_repository(OperationRepository)), + resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)),) -> AirlockRequestWithAllowedUserActions: updated_request = await cancel_request(airlock_request, user, workspace, airlock_request_repo, user_resource_repo, workspace_service_repo, resource_template_repo, operation_repo, resource_history_repo) allowed_actions = get_allowed_actions(updated_request, user, airlock_request_repo) return AirlockRequestWithAllowedUserActions(airlockRequest=updated_request, allowedUserActions=allowed_actions) @@ -122,12 +122,12 @@ async def create_review_user_resource( airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_airlock_manager_user), workspace=Depends(get_deployed_workspace_by_id_from_path), - user_resource_repo=Depends(get_repository(UserResourceRepository)), - workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), - operation_repo=Depends(get_repository(OperationRepository)), - airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), - resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), - resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> AirlockRequestAndOperationInResponse: + user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), + workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), + operation_repo=Depends(Database().get_repository(OperationRepository)), + airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), + resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), + resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> AirlockRequestAndOperationInResponse: if airlock_request.status != AirlockRequestStatus.InReview: raise HTTPException(status_code=status_code.HTTP_400_BAD_REQUEST, @@ -160,12 +160,12 @@ async def create_airlock_review( airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_airlock_manager_user), workspace=Depends(get_deployed_workspace_by_id_from_path), - airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), - user_resource_repo=Depends(get_repository(UserResourceRepository)), - workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), - operation_repo=Depends(get_repository(OperationRepository)), - resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), - resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> AirlockRequestWithAllowedUserActions: + airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), + user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), + workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), + operation_repo=Depends(Database().get_repository(OperationRepository)), + resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), + resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> AirlockRequestWithAllowedUserActions: try: updated_airlock_request = await review_airlock_request(airlock_review_input, airlock_request, user, workspace, airlock_request_repo, user_resource_repo, workspace_service_repo, operation_repo, resource_template_repo, resource_history_repo) allowed_actions = get_allowed_actions(updated_airlock_request, user, airlock_request_repo) diff --git a/api_app/api/routes/api.py b/api_app/api/routes/api.py index 5cf104f976..025717bc7a 100644 --- a/api_app/api/routes/api.py +++ b/api_app/api/routes/api.py @@ -5,7 +5,7 @@ from fastapi.openapi.docs import get_swagger_ui_html, get_swagger_ui_oauth2_redirect_html from fastapi.openapi.utils import get_openapi -from api.dependencies.database import get_repository +from api.dependencies.database import Database from db.repositories.workspaces import WorkspaceRepository from api.routes import health, ping, workspaces, workspace_templates, workspace_service_templates, user_resource_templates, \ shared_services, shared_service_templates, migrations, costs, airlock, operations, metadata @@ -116,7 +116,7 @@ def get_scope(workspace) -> str: @workspace_swagger_router.get("/workspaces/{workspace_id}/openapi.json", include_in_schema=False, name="openapi_definitions") -async def get_openapi_json(workspace_id: str, request: Request, workspace_repo=Depends(get_repository(WorkspaceRepository))): +async def get_openapi_json(workspace_id: str, request: Request, workspace_repo=Depends(Database().get_repository(WorkspaceRepository))): global openapi_definitions if openapi_definitions[workspace_id] is None: @@ -146,7 +146,7 @@ async def get_openapi_json(workspace_id: str, request: Request, workspace_repo=D @workspace_swagger_router.get("/workspaces/{workspace_id}/docs", include_in_schema=False, name="workspace_swagger") -async def get_workspace_swagger(workspace_id, request: Request, workspace_repo=Depends(get_repository(WorkspaceRepository))): +async def get_workspace_swagger(workspace_id, request: Request, workspace_repo=Depends(Database().get_repository(WorkspaceRepository))): workspace = await workspace_repo.get_workspace_by_id(workspace_id) scope = get_scope(workspace) diff --git a/api_app/api/routes/costs.py b/api_app/api/routes/costs.py index b6684756f2..97d353bbb3 100644 --- a/api_app/api/routes/costs.py +++ b/api_app/api/routes/costs.py @@ -7,7 +7,7 @@ from pydantic import UUID4 from models.schemas.costs import get_cost_report_responses, get_workspace_cost_report_responses -from api.dependencies.database import get_repository +from api.dependencies.database import Database from core import config from db.repositories.shared_services import SharedServiceRepository from db.repositories.user_resources import UserResourceRepository @@ -56,8 +56,8 @@ def __init__( async def costs( params: CostsQueryParams = Depends(), cost_service: CostService = Depends(cost_service_factory), - workspace_repo=Depends(get_repository(WorkspaceRepository)), - shared_services_repo=Depends(get_repository(SharedServiceRepository))) -> CostReport: + workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), + shared_services_repo=Depends(Database().get_repository(SharedServiceRepository))) -> CostReport: validate_report_period(params.from_date, params.to_date) try: @@ -90,9 +90,9 @@ async def costs( responses=get_workspace_cost_report_responses()) async def workspace_costs(workspace_id: UUID4, params: CostsQueryParams = Depends(), cost_service: CostService = Depends(cost_service_factory), - workspace_repo=Depends(get_repository(WorkspaceRepository)), - workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository)), - user_resource_repo=Depends(get_repository(UserResourceRepository))) -> WorkspaceCostReport: + workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), + workspace_services_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), + user_resource_repo=Depends(Database().get_repository(UserResourceRepository))) -> WorkspaceCostReport: validate_report_period(params.from_date, params.to_date) try: diff --git a/api_app/api/routes/migrations.py b/api_app/api/routes/migrations.py index 48fe49437f..fcbd32e22b 100644 --- a/api_app/api/routes/migrations.py +++ b/api_app/api/routes/migrations.py @@ -5,7 +5,7 @@ from db.repositories.resources_history import ResourceHistoryRepository from services.authentication import get_current_admin_user from resources import strings -from api.dependencies.database import get_repository +from api.dependencies.database import Database from db.migrations.shared_services import SharedServiceMigration from db.migrations.workspaces import WorkspaceMigration from db.repositories.resources import ResourceRepository @@ -20,13 +20,13 @@ name=strings.API_MIGRATE_DATABASE, response_model=MigrationOutList, dependencies=[Depends(get_current_admin_user)]) -async def migrate_database(resources_repo=Depends(get_repository(ResourceRepository)), - operations_repo=Depends(get_repository(OperationRepository)), - resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), - shared_services_migration=Depends(get_repository(SharedServiceMigration)), - workspace_migration=Depends(get_repository(WorkspaceMigration)), - resource_migration=Depends(get_repository(ResourceMigration)), - airlock_migration=Depends(get_repository(AirlockMigration)),): +async def migrate_database(resources_repo=Depends(Database().get_repository(ResourceRepository)), + operations_repo=Depends(Database().get_repository(OperationRepository)), + resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), + shared_services_migration=Depends(Database().get_repository(SharedServiceMigration)), + workspace_migration=Depends(Database().get_repository(WorkspaceMigration)), + resource_migration=Depends(Database().get_repository(ResourceMigration)), + airlock_migration=Depends(Database().get_repository(AirlockMigration)),): try: migrations = list() logger.info("PR 1030") diff --git a/api_app/api/routes/operations.py b/api_app/api/routes/operations.py index d5a707bebf..ee3c693a5b 100644 --- a/api_app/api/routes/operations.py +++ b/api_app/api/routes/operations.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends from db.repositories.operations import OperationRepository -from api.dependencies.database import get_repository +from api.dependencies.database import Database from models.schemas.operation import OperationInList from resources import strings from services.authentication import get_current_tre_user_or_tre_admin @@ -11,6 +11,6 @@ @operations_router.get("/operations", response_model=OperationInList, name=strings.API_GET_MY_OPERATIONS) -async def get_my_operations(user=Depends(get_current_tre_user_or_tre_admin), operations_repo=Depends(get_repository(OperationRepository))) -> OperationInList: +async def get_my_operations(user=Depends(get_current_tre_user_or_tre_admin), operations_repo=Depends(Database().get_repository(OperationRepository))) -> OperationInList: operations = await operations_repo.get_my_operations(user_id=user.id) return OperationInList(operations=operations) diff --git a/api_app/api/routes/shared_service_templates.py b/api_app/api/routes/shared_service_templates.py index fee2369ae6..a2c32f5aa6 100644 --- a/api_app/api/routes/shared_service_templates.py +++ b/api_app/api/routes/shared_service_templates.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import get_repository +from api.dependencies.database import Database from db.errors import EntityDoesNotExist, EntityVersionExist, InvalidInput from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType @@ -17,13 +17,13 @@ @shared_service_templates_core_router.get("/shared-service-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_SHARED_SERVICE_TEMPLATES) -async def get_shared_service_templates(authorized_only: bool = False, template_repo=Depends(get_repository(ResourceTemplateRepository)), user=Depends(get_current_tre_user_or_tre_admin)) -> ResourceTemplateInformationInList: +async def get_shared_service_templates(authorized_only: bool = False, template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), user=Depends(get_current_tre_user_or_tre_admin)) -> ResourceTemplateInformationInList: templates_infos = await template_repo.get_templates_information(ResourceType.SharedService, user.roles if authorized_only else None) return ResourceTemplateInformationInList(templates=templates_infos) @shared_service_templates_core_router.get("/shared-service-templates/{shared_service_template_name}", response_model=SharedServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_GET_SHARED_SERVICE_TEMPLATE_BY_NAME, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_shared_service_template(shared_service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> SharedServiceTemplateInResponse: +async def get_shared_service_template(shared_service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> SharedServiceTemplateInResponse: try: template = await get_template(shared_service_template_name, template_repo, ResourceType.SharedService, is_update=is_update, version=version) return parse_obj_as(SharedServiceTemplateInResponse, template) @@ -32,7 +32,7 @@ async def get_shared_service_template(shared_service_template_name: str, is_upda @shared_service_templates_core_router.post("/shared-service-templates", status_code=status.HTTP_201_CREATED, response_model=SharedServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_SHARED_SERVICE_TEMPLATES, dependencies=[Depends(get_current_admin_user)]) -async def register_shared_service_template(template_input: SharedServiceTemplateInCreate, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: +async def register_shared_service_template(template_input: SharedServiceTemplateInCreate, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.SharedService) except EntityVersionExist: diff --git a/api_app/api/routes/shared_services.py b/api_app/api/routes/shared_services.py index f09c3f996f..baee1bbc02 100644 --- a/api_app/api/routes/shared_services.py +++ b/api_app/api/routes/shared_services.py @@ -5,7 +5,7 @@ from db.repositories.operations import OperationRepository from db.errors import DuplicateEntity, MajorVersionUpdateDenied, UserNotAuthorizedToUseTemplate, TargetTemplateVersionDoesNotExist, VersionDowngradeDenied -from api.dependencies.database import get_repository +from api.dependencies.database import Database from api.dependencies.shared_services import get_shared_service_by_id_from_path, get_operation_by_id_from_path from db.repositories.resource_templates import ResourceTemplateRepository from db.repositories.resources_history import ResourceHistoryRepository @@ -33,7 +33,7 @@ def user_is_tre_admin(user): @shared_services_router.get("/shared-services", response_model=SharedServicesInList, name=strings.API_GET_ALL_SHARED_SERVICES, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def retrieve_shared_services(shared_services_repo=Depends(get_repository(SharedServiceRepository)), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))) -> SharedServicesInList: +async def retrieve_shared_services(shared_services_repo=Depends(Database().get_repository(SharedServiceRepository)), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> SharedServicesInList: shared_services = await shared_services_repo.get_active_shared_services() await asyncio.gather(*[enrich_resource_with_available_upgrades(shared_service, resource_template_repo) for shared_service in shared_services]) if user_is_tre_admin(user): @@ -43,7 +43,7 @@ async def retrieve_shared_services(shared_services_repo=Depends(get_repository(S @shared_services_router.get("/shared-services/{shared_service_id}", response_model=SharedServiceInResponse, name=strings.API_GET_SHARED_SERVICE_BY_ID, dependencies=[Depends(get_current_tre_user_or_tre_admin), Depends(get_shared_service_by_id_from_path)]) -async def retrieve_shared_service_by_id(shared_service=Depends(get_shared_service_by_id_from_path), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))): +async def retrieve_shared_service_by_id(shared_service=Depends(get_shared_service_by_id_from_path), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))): await enrich_resource_with_available_upgrades(shared_service, resource_template_repo) if user_is_tre_admin(user): return SharedServiceInResponse(sharedService=shared_service) @@ -52,7 +52,7 @@ async def retrieve_shared_service_by_id(shared_service=Depends(get_shared_servic @shared_services_router.post("/shared-services", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_CREATE_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user)]) -async def create_shared_service(response: Response, shared_service_input: SharedServiceInCreate, user=Depends(get_current_admin_user), shared_services_repo=Depends(get_repository(SharedServiceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def create_shared_service(response: Response, shared_service_input: SharedServiceInCreate, user=Depends(get_current_admin_user), shared_services_repo=Depends(Database().get_repository(SharedServiceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: try: shared_service, resource_template = await shared_services_repo.create_shared_service_item(shared_service_input, user.roles) except (ValidationError, ValueError) as e: @@ -83,7 +83,7 @@ async def create_shared_service(response: Response, shared_service_input: Shared response_model=OperationInResponse, name=strings.API_UPDATE_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user), Depends(get_shared_service_by_id_from_path)]) -async def patch_shared_service(shared_service_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), shared_service_repo=Depends(get_repository(SharedServiceRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), etag: str = Header(...), force_version_update: bool = False) -> SharedServiceInResponse: +async def patch_shared_service(shared_service_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), shared_service_repo=Depends(Database().get_repository(SharedServiceRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), etag: str = Header(...), force_version_update: bool = False) -> SharedServiceInResponse: try: patched_shared_service, _ = await shared_service_repo.patch_shared_service(shared_service, shared_service_patch, etag, resource_template_repo, resource_history_repo, user, force_version_update) operation = await send_resource_request_message( @@ -106,7 +106,7 @@ async def patch_shared_service(shared_service_patch: ResourcePatch, response: Re @shared_services_router.delete("/shared-services/{shared_service_id}", response_model=OperationInResponse, name=strings.API_DELETE_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user)]) -async def delete_shared_service(response: Response, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(get_repository(OperationRepository)), shared_service_repo=Depends(get_repository(SharedServiceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def delete_shared_service(response: Response, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(Database().get_repository(OperationRepository)), shared_service_repo=Depends(Database().get_repository(SharedServiceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: if shared_service.isEnabled: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=strings.SHARED_SERVICE_NEEDS_TO_BE_DISABLED_BEFORE_DELETION) @@ -125,7 +125,7 @@ async def delete_shared_service(response: Response, user=Depends(get_current_adm @shared_services_router.post("/shared-services/{shared_service_id}/invoke-action", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_INVOKE_ACTION_ON_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user)]) -async def invoke_action_on_shared_service(response: Response, action: str, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), shared_service_repo=Depends(get_repository(SharedServiceRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def invoke_action_on_shared_service(response: Response, action: str, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), shared_service_repo=Depends(Database().get_repository(SharedServiceRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: operation = await send_custom_action_message( resource=shared_service, resource_repo=shared_service_repo, @@ -143,7 +143,7 @@ async def invoke_action_on_shared_service(response: Response, action: str, user= # Shared service operations @shared_services_router.get("/shared-services/{shared_service_id}/operations", response_model=OperationInList, name=strings.API_GET_RESOURCE_OPERATIONS, dependencies=[Depends(get_current_admin_user), Depends(get_shared_service_by_id_from_path)]) -async def retrieve_shared_service_operations_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(get_repository(OperationRepository))) -> OperationInList: +async def retrieve_shared_service_operations_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(Database().get_repository(OperationRepository))) -> OperationInList: return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=shared_service.id)) @@ -154,5 +154,5 @@ async def retrieve_shared_service_operation_by_shared_service_id_and_operation_i # Shared service history @shared_services_router.get("/shared-services/{shared_service_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_current_admin_user)]) -async def retrieve_shared_service_history_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: +async def retrieve_shared_service_history_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=shared_service.id)) diff --git a/api_app/api/routes/user_resource_templates.py b/api_app/api/routes/user_resource_templates.py index b4cc9f9b6b..1f0d860652 100644 --- a/api_app/api/routes/user_resource_templates.py +++ b/api_app/api/routes/user_resource_templates.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import get_repository +from api.dependencies.database import Database from api.dependencies.workspace_service_templates import get_workspace_service_template_by_name_from_path from api.routes.resource_helpers import get_template from db.errors import EntityVersionExist, InvalidInput @@ -19,19 +19,19 @@ @user_resource_templates_core_router.get("/workspace-service-templates/{service_template_name}/user-resource-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_USER_RESOURCE_TEMPLATES, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_user_resource_templates_for_service_template(service_template_name: str, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplateInformationInList: +async def get_user_resource_templates_for_service_template(service_template_name: str, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplateInformationInList: template_infos = await template_repo.get_templates_information(ResourceType.UserResource, parent_service_name=service_template_name) return ResourceTemplateInformationInList(templates=template_infos) @user_resource_templates_core_router.get("/workspace-service-templates/{service_template_name}/user-resource-templates/{user_resource_template_name}", response_model=UserResourceTemplateInResponse, response_model_exclude_none=True, name=strings.API_GET_USER_RESOURCE_TEMPLATE_BY_NAME, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_user_resource_template(service_template_name: str, user_resource_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> UserResourceTemplateInResponse: +async def get_user_resource_template(service_template_name: str, user_resource_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> UserResourceTemplateInResponse: template = await get_template(user_resource_template_name, template_repo, ResourceType.UserResource, service_template_name, is_update=is_update, version=version) return parse_obj_as(UserResourceTemplateInResponse, template) @user_resource_templates_core_router.post("/workspace-service-templates/{service_template_name}/user-resource-templates", status_code=status.HTTP_201_CREATED, response_model=UserResourceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_USER_RESOURCE_TEMPLATES, dependencies=[Depends(get_current_admin_user)]) -async def register_user_resource_template(template_input: UserResourceTemplateInCreate, template_repo=Depends(get_repository(ResourceTemplateRepository)), workspace_service_template=Depends(get_workspace_service_template_by_name_from_path)) -> UserResourceTemplateInResponse: +async def register_user_resource_template(template_input: UserResourceTemplateInCreate, template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), workspace_service_template=Depends(get_workspace_service_template_by_name_from_path)) -> UserResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.UserResource, workspace_service_template.name) except EntityVersionExist: diff --git a/api_app/api/routes/workspace_service_templates.py b/api_app/api/routes/workspace_service_templates.py index c04558ad81..2db35543c5 100644 --- a/api_app/api/routes/workspace_service_templates.py +++ b/api_app/api/routes/workspace_service_templates.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import get_repository +from api.dependencies.database import Database from api.routes.resource_helpers import get_template from db.errors import EntityVersionExist, InvalidInput from db.repositories.resource_templates import ResourceTemplateRepository @@ -17,19 +17,19 @@ @workspace_service_templates_core_router.get("/workspace-service-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_WORKSPACE_SERVICE_TEMPLATES, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_workspace_service_templates(template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplateInformationInList: +async def get_workspace_service_templates(template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplateInformationInList: templates_infos = await template_repo.get_templates_information(ResourceType.WorkspaceService) return ResourceTemplateInformationInList(templates=templates_infos) @workspace_service_templates_core_router.get("/workspace-service-templates/{service_template_name}", response_model=WorkspaceServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_GET_WORKSPACE_SERVICE_TEMPLATE_BY_NAME, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_workspace_service_template(service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspaceServiceTemplateInResponse: +async def get_workspace_service_template(service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspaceServiceTemplateInResponse: template = await get_template(service_template_name, template_repo, ResourceType.WorkspaceService, is_update=is_update, version=version) return parse_obj_as(WorkspaceServiceTemplateInResponse, template) @workspace_service_templates_core_router.post("/workspace-service-templates", status_code=status.HTTP_201_CREATED, response_model=WorkspaceServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_WORKSPACE_SERVICE_TEMPLATES, dependencies=[Depends(get_current_admin_user)]) -async def register_workspace_service_template(template_input: WorkspaceServiceTemplateInCreate, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: +async def register_workspace_service_template(template_input: WorkspaceServiceTemplateInCreate, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.WorkspaceService) except EntityVersionExist: diff --git a/api_app/api/routes/workspace_templates.py b/api_app/api/routes/workspace_templates.py index 6ba864724d..56d5c3153b 100644 --- a/api_app/api/routes/workspace_templates.py +++ b/api_app/api/routes/workspace_templates.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import get_repository +from api.dependencies.database import Database from db.errors import EntityVersionExist, InvalidInput from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType @@ -17,19 +17,19 @@ @workspace_templates_admin_router.get("/workspace-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_WORKSPACE_TEMPLATES) -async def get_workspace_templates(authorized_only: bool = False, template_repo=Depends(get_repository(ResourceTemplateRepository)), user=Depends(get_current_admin_user)) -> ResourceTemplateInformationInList: +async def get_workspace_templates(authorized_only: bool = False, template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), user=Depends(get_current_admin_user)) -> ResourceTemplateInformationInList: templates_infos = await template_repo.get_templates_information(ResourceType.Workspace, user.roles if authorized_only else None) return ResourceTemplateInformationInList(templates=templates_infos) @workspace_templates_admin_router.get("/workspace-templates/{workspace_template_name}", response_model=WorkspaceTemplateInResponse, name=strings.API_GET_WORKSPACE_TEMPLATE_BY_NAME, response_model_exclude_none=True) -async def get_workspace_template(workspace_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspaceTemplateInResponse: +async def get_workspace_template(workspace_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspaceTemplateInResponse: template = await get_template(workspace_template_name, template_repo, ResourceType.Workspace, is_update=is_update, version=version) return parse_obj_as(WorkspaceTemplateInResponse, template) @workspace_templates_admin_router.post("/workspace-templates", status_code=status.HTTP_201_CREATED, response_model=WorkspaceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_WORKSPACE_TEMPLATES) -async def register_workspace_template(template_input: WorkspaceTemplateInCreate, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: +async def register_workspace_template(template_input: WorkspaceTemplateInCreate, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.Workspace) except EntityVersionExist: diff --git a/api_app/api/routes/workspaces.py b/api_app/api/routes/workspaces.py index a086812ff5..a0a7da34eb 100644 --- a/api_app/api/routes/workspaces.py +++ b/api_app/api/routes/workspaces.py @@ -4,7 +4,7 @@ from jsonschema.exceptions import ValidationError -from api.dependencies.database import get_repository +from api.dependencies.database import Database from api.dependencies.workspaces import get_operation_by_id_from_path, get_workspace_by_id_from_path, get_deployed_workspace_by_id_from_path, get_deployed_workspace_service_by_id_from_path, get_workspace_service_by_id_from_path, get_user_resource_by_id_from_path from db.errors import InvalidInput, MajorVersionUpdateDenied, TargetTemplateVersionDoesNotExist, UserNotAuthorizedToUseTemplate, VersionDowngradeDenied from db.repositories.operations import OperationRepository @@ -56,7 +56,7 @@ def validate_user_has_valid_role_for_user_resource(user, user_resource): # WORKSPACE ROUTES @workspaces_core_router.get("/workspaces", response_model=WorkspacesInList, name=strings.API_GET_ALL_WORKSPACES) -async def retrieve_users_active_workspaces(request: Request, user=Depends(get_current_tre_user_or_tre_admin), workspace_repo=Depends(get_repository(WorkspaceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspacesInList: +async def retrieve_users_active_workspaces(request: Request, user=Depends(get_current_tre_user_or_tre_admin), workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspacesInList: try: user = await get_current_admin_user(request) @@ -83,7 +83,7 @@ def _safe_get_workspace_role(user, workspace, user_role_assignments): @workspaces_shared_router.get("/workspaces/{workspace_id}", response_model=WorkspaceInResponse, name=strings.API_GET_WORKSPACE_BY_ID) -async def retrieve_workspace_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspaceInResponse: +async def retrieve_workspace_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspaceInResponse: await enrich_resource_with_available_upgrades(workspace, resource_template_repo) return WorkspaceInResponse(workspace=workspace) @@ -97,7 +97,7 @@ async def retrieve_workspace_scope_id_by_workspace_id(workspace=Depends(get_work @workspaces_core_router.post("/workspaces", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_CREATE_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def create_workspace(workspace_create: WorkspaceInCreate, response: Response, user=Depends(get_current_admin_user), workspace_repo=Depends(get_repository(WorkspaceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def create_workspace(workspace_create: WorkspaceInCreate, response: Response, user=Depends(get_current_admin_user), workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: try: # TODO: This requires Directory.ReadAll ( Application.Read.All ) to be enabled in the Azure AD application to enable a users workspaces to be listed. This should be made optional. auth_info = extract_auth_information(workspace_create.properties) @@ -125,7 +125,7 @@ async def create_workspace(workspace_create: WorkspaceInCreate, response: Respon @workspaces_core_router.patch("/workspaces/{workspace_id}", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_UPDATE_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def patch_workspace(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), workspace_repo: WorkspaceRepository = Depends(get_repository(WorkspaceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: +async def patch_workspace(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), workspace_repo: WorkspaceRepository = Depends(Database().get_repository(WorkspaceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: try: is_disablement = resource_patch.isEnabled is not None and not resource_patch.isEnabled if is_disablement: @@ -153,7 +153,7 @@ async def patch_workspace(resource_patch: ResourcePatch, response: Response, use @workspaces_core_router.delete("/workspaces/{workspace_id}", response_model=OperationInResponse, name=strings.API_DELETE_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def delete_workspace(response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(get_repository(OperationRepository)), workspace_repo=Depends(get_repository(WorkspaceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def delete_workspace(response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(Database().get_repository(OperationRepository)), workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: if await delete_validation(workspace, workspace_repo): operation = await send_uninstall_message( resource=workspace, @@ -171,7 +171,7 @@ async def delete_workspace(response: Response, user=Depends(get_current_admin_us @workspaces_core_router.post("/workspaces/{workspace_id}/invoke-action", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_INVOKE_ACTION_ON_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def invoke_action_on_workspace(response: Response, action: str, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), workspace_repo=Depends(get_repository(WorkspaceRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def invoke_action_on_workspace(response: Response, action: str, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: operation = await send_custom_action_message( resource=workspace, resource_repo=workspace_repo, @@ -192,7 +192,7 @@ async def invoke_action_on_workspace(response: Response, action: str, user=Depen @workspaces_shared_router.get("/workspaces/{workspace_id}/workspace-service-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_WORKSPACE_SERVICE_TEMPLATES_IN_WORKSPACE) async def get_workspace_service_templates( workspace=Depends(get_workspace_by_id_from_path), - template_repo=Depends(get_repository(ResourceTemplateRepository)), + template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager_or_tre_admin)) -> ResourceTemplateInformationInList: template_infos = await template_repo.get_templates_information(ResourceType.WorkspaceService, user.roles) return ResourceTemplateInformationInList(templates=template_infos) @@ -203,14 +203,14 @@ async def get_workspace_service_templates( async def get_user_resource_templates( service_template_name: str, workspace=Depends(get_workspace_by_id_from_path), - template_repo=Depends(get_repository(ResourceTemplateRepository)), + template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager_or_tre_admin)) -> ResourceTemplateInformationInList: template_infos = await template_repo.get_templates_information(ResourceType.UserResource, user.roles, service_template_name) return ResourceTemplateInformationInList(templates=template_infos) @workspaces_shared_router.get("/workspaces/{workspace_id}/operations", response_model=OperationInList, name=strings.API_GET_RESOURCE_OPERATIONS, dependencies=[Depends(get_current_workspace_owner_or_tre_admin)]) -async def retrieve_workspace_operations_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(get_repository(OperationRepository))) -> OperationInList: +async def retrieve_workspace_operations_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(Database().get_repository(OperationRepository))) -> OperationInList: return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=workspace.id)) @@ -220,26 +220,26 @@ async def retrieve_workspace_operation_by_workspace_id_and_operation_id(workspac @workspaces_shared_router.get("/workspaces/{workspace_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_current_workspace_owner_or_tre_admin)]) -async def retrieve_workspace_history_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: +async def retrieve_workspace_history_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=workspace.id)) # WORKSPACE SERVICES ROUTES @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services", response_model=WorkspaceServicesInList, name=strings.API_GET_ALL_WORKSPACE_SERVICES, dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)]) -async def retrieve_users_active_workspace_services(workspace=Depends(get_workspace_by_id_from_path), workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspaceServicesInList: +async def retrieve_users_active_workspace_services(workspace=Depends(get_workspace_by_id_from_path), workspace_services_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspaceServicesInList: workspace_services = await workspace_services_repo.get_active_workspace_services_for_workspace(workspace.id) await asyncio.gather(*[enrich_resource_with_available_upgrades(workspace_service, resource_template_repo) for workspace_service in workspace_services]) return WorkspaceServicesInList(workspaceServices=workspace_services) @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}", response_model=WorkspaceServiceInResponse, name=strings.API_GET_WORKSPACE_SERVICE_BY_ID, dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) -async def retrieve_workspace_service_by_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspaceServiceInResponse: +async def retrieve_workspace_service_by_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspaceServiceInResponse: await enrich_resource_with_available_upgrades(workspace_service, resource_template_repo) return WorkspaceServiceInResponse(workspaceService=workspace_service) @workspace_services_workspace_router.post("/workspaces/{workspace_id}/workspace-services", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_CREATE_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_user)]) -async def create_workspace_service(response: Response, workspace_service_input: WorkspaceServiceInCreate, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), workspace_repo=Depends(get_repository(WorkspaceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), workspace=Depends(get_deployed_workspace_by_id_from_path)) -> OperationInResponse: +async def create_workspace_service(response: Response, workspace_service_input: WorkspaceServiceInCreate, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), workspace=Depends(get_deployed_workspace_by_id_from_path)) -> OperationInResponse: try: workspace_service, resource_template = await workspace_service_repo.create_workspace_service_item(workspace_service_input, workspace.id, user.roles) @@ -280,7 +280,7 @@ async def create_workspace_service(response: Response, workspace_service_input: @workspace_services_workspace_router.patch("/workspaces/{workspace_id}/workspace-services/{service_id}", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_UPDATE_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_or_researcher_user), Depends(get_workspace_by_id_from_path)]) -async def patch_workspace_service(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: +async def patch_workspace_service(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: try: is_disablement = resource_patch.isEnabled is not None and not resource_patch.isEnabled if is_disablement: @@ -306,7 +306,7 @@ async def patch_workspace_service(resource_patch: ResourcePatch, response: Respo @workspace_services_workspace_router.delete("/workspaces/{workspace_id}/workspace-services/{service_id}", response_model=OperationInResponse, name=strings.API_DELETE_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_user)]) -async def delete_workspace_service(response: Response, user=Depends(get_current_workspace_owner_user), workspace=Depends(get_workspace_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), user_resource_repo=Depends(get_repository(UserResourceRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def delete_workspace_service(response: Response, user=Depends(get_current_workspace_owner_user), workspace=Depends(get_workspace_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: if await delete_validation(workspace_service, workspace_service_repo): operation = await send_uninstall_message( resource=workspace_service, @@ -324,7 +324,7 @@ async def delete_workspace_service(response: Response, user=Depends(get_current_ @workspace_services_workspace_router.post("/workspaces/{workspace_id}/workspace-services/{service_id}/invoke-action", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_INVOKE_ACTION_ON_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_user)]) -async def invoke_action_on_workspace_service(response: Response, action: str, user=Depends(get_current_workspace_owner_user), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def invoke_action_on_workspace_service(response: Response, action: str, user=Depends(get_current_workspace_owner_user), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: operation = await send_custom_action_message( resource=workspace_service, resource_repo=workspace_service_repo, @@ -342,7 +342,7 @@ async def invoke_action_on_workspace_service(response: Response, action: str, us # workspace service operations @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/operations", response_model=OperationInList, name=strings.API_GET_RESOURCE_OPERATIONS, dependencies=[Depends(get_current_workspace_owner_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) -async def retrieve_workspace_service_operations_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), operations_repo=Depends(get_repository(OperationRepository))) -> OperationInList: +async def retrieve_workspace_service_operations_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), operations_repo=Depends(Database().get_repository(OperationRepository))) -> OperationInList: return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=workspace_service.id)) @@ -352,7 +352,7 @@ async def retrieve_workspace_service_operation_by_workspace_service_id_and_opera @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_current_workspace_owner_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) -async def retrieve_workspace_service_history_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: +async def retrieve_workspace_service_history_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=workspace_service.id)) @@ -362,8 +362,8 @@ async def retrieve_user_resources_for_workspace_service( workspace_id: str, service_id: str, user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), - resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), - user_resource_repo=Depends(get_repository(UserResourceRepository))) -> UserResourcesInList: + resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), + user_resource_repo=Depends(Database().get_repository(UserResourceRepository))) -> UserResourcesInList: user_resources = await user_resource_repo.get_user_resources_for_workspace_service(workspace_id, service_id) # filter only to the user - for researchers @@ -382,7 +382,7 @@ async def retrieve_user_resources_for_workspace_service( @user_resources_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/user-resources/{resource_id}", response_model=UserResourceInResponse, name=strings.API_GET_USER_RESOURCE, dependencies=[Depends(get_workspace_by_id_from_path)]) async def retrieve_user_resource_by_id( user_resource=Depends(get_user_resource_by_id_from_path), - resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), + resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)) -> UserResourceInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) @@ -397,10 +397,10 @@ async def retrieve_user_resource_by_id( async def create_user_resource( response: Response, user_resource_create: UserResourceInCreate, - user_resource_repo=Depends(get_repository(UserResourceRepository)), - resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), - operations_repo=Depends(get_repository(OperationRepository)), - resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), + user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), + resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), + operations_repo=Depends(Database().get_repository(OperationRepository)), + resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), workspace=Depends(get_deployed_workspace_by_id_from_path), workspace_service=Depends(get_deployed_workspace_service_by_id_from_path)) -> OperationInResponse: @@ -433,10 +433,10 @@ async def delete_user_resource( user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), user_resource=Depends(get_user_resource_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), - user_resource_repo=Depends(get_repository(UserResourceRepository)), - operations_repo=Depends(get_repository(OperationRepository)), - resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), - resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: + user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), + operations_repo=Depends(Database().get_repository(OperationRepository)), + resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), + resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) if user_resource.isEnabled: @@ -463,10 +463,10 @@ async def patch_user_resource( user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), user_resource=Depends(get_user_resource_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), - user_resource_repo=Depends(get_repository(UserResourceRepository)), - resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), - resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), - operations_repo=Depends(get_repository(OperationRepository)), + user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), + resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), + resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), + operations_repo=Depends(Database().get_repository(OperationRepository)), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) @@ -490,10 +490,10 @@ async def invoke_action_on_user_resource( action: str, user_resource=Depends(get_user_resource_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), - resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), - user_resource_repo=Depends(get_repository(UserResourceRepository)), - operations_repo=Depends(get_repository(OperationRepository)), - resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), + resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), + user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), + operations_repo=Depends(Database().get_repository(OperationRepository)), + resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)) -> OperationInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) operation = await send_custom_action_message( @@ -517,7 +517,7 @@ async def invoke_action_on_user_resource( async def retrieve_user_resource_operations_by_user_resource_id( user_resource=Depends(get_user_resource_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), - operations_repo=Depends(get_repository(OperationRepository))) -> OperationInList: + operations_repo=Depends(Database().get_repository(OperationRepository))) -> OperationInList: validate_user_has_valid_role_for_user_resource(user, user_resource) return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=user_resource.id)) @@ -532,6 +532,6 @@ async def retrieve_user_resource_operations_by_user_resource_id_and_operation_id @user_resources_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/user-resources/{resource_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_workspace_by_id_from_path)]) -async def retrieve_user_resource_history_by_user_resource_id(user_resource=Depends(get_user_resource_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: +async def retrieve_user_resource_history_by_user_resource_id(user_resource=Depends(get_user_resource_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: validate_user_has_valid_role_for_user_resource(user, user_resource) return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=user_resource.id)) diff --git a/api_app/main.py b/api_app/main.py index b7e7fa8aaf..8d88e9f6e8 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -30,10 +30,10 @@ async def lifespan(app: FastAPI): await asyncio.sleep(5) logger.warning("Database connection could not be established") - deploymentStatusUpdater = DeploymentStatusUpdater(app) + deploymentStatusUpdater = DeploymentStatusUpdater() await deploymentStatusUpdater.init_repos() - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() asyncio.create_task(deploymentStatusUpdater.receive_messages()) diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index e26f573303..babb302702 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -6,7 +6,7 @@ from fastapi import HTTPException from pydantic import ValidationError, parse_obj_as -from api.dependencies.database import get_db_client +from api.dependencies.database import Database from api.dependencies.airlock import get_airlock_request_by_id_from_path from services.airlock import update_and_publish_event_airlock_request from services.logging import logger, tracer @@ -20,11 +20,11 @@ class AirlockStatusUpdater(): - def __init__(self, app): - self.app = app + def __init__(self): + pass async def init_repos(self): - db_client = await get_db_client(self.app) + db_client = await Database().get_db_client() self.airlock_request_repo = await AirlockRequestRepository.create(db_client) self.workspace_repo = await WorkspaceRepository.create(db_client) diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index b38e138927..255b25d1c3 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -4,7 +4,7 @@ from pydantic import ValidationError, parse_obj_as -from api.dependencies.database import get_db_client +from api.dependencies.database import Database from api.routes.resource_helpers import get_timestamp from models.domain.resource import Output from db.repositories.resources_history import ResourceHistoryRepository @@ -24,11 +24,11 @@ class DeploymentStatusUpdater(): - def __init__(self, app): - self.app = app + def __init__(self): + pass async def init_repos(self): - db_client = await get_db_client(self.app) + db_client = await Database().get_db_client() self.operations_repo = await OperationRepository.create(db_client) self.resource_repo = await ResourceRepository.create(db_client) self.resource_template_repo = await ResourceTemplateRepository.create(db_client) diff --git a/api_app/services/aad_authentication.py b/api_app/services/aad_authentication.py index fba97619ed..8842a0c607 100644 --- a/api_app/services/aad_authentication.py +++ b/api_app/services/aad_authentication.py @@ -14,7 +14,7 @@ from models.domain.authentication import User, RoleAssignment from models.domain.workspace import Workspace, WorkspaceRole from resources import strings -from api.dependencies.database import get_db_client_from_request +from api.dependencies.database import Database from db.repositories.workspaces import WorkspaceRepository from services.logging import logger @@ -120,7 +120,8 @@ async def _fetch_ws_app_reg_id_from_ws_id(request: Request) -> str: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=strings.AUTH_COULD_NOT_VALIDATE_CREDENTIALS) try: workspace_id = request.path_params['workspace_id'] - db_client = await get_db_client_from_request(request) + + db_client = Database().get_db_client() ws_repo = await WorkspaceRepository.create(db_client) workspace = await ws_repo.get_workspace_by_id(workspace_id) diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index aa8a5013a2..6314fa9a79 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -4,7 +4,7 @@ from azure.mgmt.compute.aio import ComputeManagementClient from azure.cosmos.exceptions import CosmosHttpResponseError from azure.servicebus.exceptions import ServiceBusConnectionError, ServiceBusAuthenticationError -from api.dependencies.database import get_db_client_from_request +from api.dependencies.database import Database from core import config from models.schemas.status import StatusEnum @@ -16,7 +16,7 @@ async def create_state_store_status(request) -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" try: - cosmos_client = await get_db_client_from_request(request) + cosmos_client = await Database().get_db_client() async with cosmos_client: list_databases_response = cosmos_client.list_databases() [database async for database in list_databases_response] diff --git a/api_app/tests_ma/conftest.py b/api_app/tests_ma/conftest.py index 1afbaf3819..ccbb44ebfc 100644 --- a/api_app/tests_ma/conftest.py +++ b/api_app/tests_ma/conftest.py @@ -575,8 +575,8 @@ def simple_pipeline_step() -> PipelineStep: @pytest_asyncio.fixture() def no_database(): """overrides connecting to the database""" - with patch("api.dependencies.database.connect_to_db", return_value=None): - with patch("api.dependencies.database.get_db_client", return_value=None): + with patch("api.dependencies.database.Database._connect_to_db", return_value=None): + with patch("api.dependencies.database.Database.get_db_client", return_value=None): with patch( "db.repositories.base.BaseRepository._get_container", return_value=None ): diff --git a/api_app/tests_ma/test_api/conftest.py b/api_app/tests_ma/test_api/conftest.py index a247b91ea5..a22a6080e0 100644 --- a/api_app/tests_ma/test_api/conftest.py +++ b/api_app/tests_ma/test_api/conftest.py @@ -17,8 +17,8 @@ def no_lifespan_events(): @pytest_asyncio.fixture(autouse=True) def no_database(): """ overrides connecting to the database for all tests""" - with patch('api.dependencies.database.connect_to_db', return_value=None): - with patch('api.dependencies.database.get_db_client', return_value=None): + with patch('api.dependencies.database.Database._connect_to_db', return_value=None): + with patch('api.dependencies.database.Database.get_db_client', return_value=None): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): with patch('db.events.bootstrap_database', return_value=None): yield diff --git a/api_app/tests_ma/test_api/test_routes/test_workspaces.py b/api_app/tests_ma/test_api/test_routes/test_workspaces.py index 577391fe7f..feed3f2bc2 100644 --- a/api_app/tests_ma/test_api/test_routes/test_workspaces.py +++ b/api_app/tests_ma/test_api/test_routes/test_workspaces.py @@ -676,7 +676,7 @@ async def test_delete_workspace_returns_400_if_associated_workspace_services_are # [DELETE] /workspaces/{workspace_id} @ patch("api.routes.resource_helpers.ResourceRepository.get_resource_dependency_list") @ patch("api.routes.workspaces.ResourceTemplateRepository.get_template_by_name_and_version") - @ patch("api.dependencies.workspaces.get_repository") + @ patch("api.dependencies.database.Database.get_repository") @ patch("api.dependencies.workspaces.WorkspaceRepository.get_workspace_by_id") @ patch("api.routes.workspaces.WorkspaceServiceRepository.get_active_workspace_services_for_workspace", return_value=[]) @ patch('azure.cosmos.CosmosClient') @@ -692,7 +692,7 @@ async def test_delete_workspace_sends_a_request_message_to_uninstall_the_workspa # [DELETE] /workspaces/{workspace_id} @ patch("api.routes.resource_helpers.ResourceRepository.get_resource_dependency_list") @ patch("api.routes.workspaces.ResourceTemplateRepository.get_template_by_name_and_version") - @ patch("api.dependencies.workspaces.get_repository") + @ patch("api.dependencies.database.Database.get_repository") @ patch("api.dependencies.workspaces.WorkspaceRepository.get_workspace_by_id") @ patch("api.routes.workspaces.WorkspaceServiceRepository.get_active_workspace_services_for_workspace") @ patch('azure.cosmos.CosmosClient') diff --git a/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py b/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py index 66829a4a47..76aed496c7 100644 --- a/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py @@ -108,9 +108,9 @@ def __str__(self): @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('logging.exception') -@patch('fastapi.FastAPI') +@patch("api.dependencies.database.Database.get_db_client") @patch("services.aad_authentication.AzureADAuthorization.get_workspace_role_assignment_details", return_value={"researcher_emails": ["researcher@outlook.com"], "owner_emails": ["owner@outlook.com"]}) -async def test_receiving_good_message(_, app, logging_mock, workspace_repo, airlock_request_repo, eg_client): +async def test_receiving_good_message(_, cosmos_client, logging_mock, workspace_repo, airlock_request_repo, eg_client): eg_client().send = AsyncMock() expected_airlock_request = sample_airlock_request() @@ -118,7 +118,7 @@ async def test_receiving_good_message(_, app, logging_mock, workspace_repo, airl airlock_request_repo.return_value.update_airlock_request.return_value = sample_airlock_request(status=AirlockRequestStatus.InReview) workspace_repo.return_value.get_workspace_by_id.return_value = sample_workspace() - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(ServiceBusReceivedMessageMock(test_sb_step_result_message)) @@ -140,10 +140,10 @@ async def test_receiving_good_message(_, app, logging_mock, workspace_repo, airl @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_receiving_bad_json_logs_error(app, logging_mock, workspace_repo, airlock_request_repo, payload): +@patch("api.dependencies.database.Database.get_db_client") +async def test_receiving_bad_json_logs_error(cosmos_client, logging_mock, workspace_repo, airlock_request_repo, payload): service_bus_received_message_mock = ServiceBusReceivedMessageMock(payload) - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(service_bus_received_message_mock) @@ -156,12 +156,12 @@ async def test_receiving_bad_json_logs_error(app, logging_mock, workspace_repo, @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.exception') @patch('service_bus.airlock_request_status_update.ServiceBusClient') -@patch('fastapi.FastAPI') -async def test_updating_non_existent_airlock_request_error_is_logged(app, sb_client, logging_mock, airlock_request_repo, _): +@patch("api.dependencies.database.Database.get_db_client") +async def test_updating_non_existent_airlock_request_error_is_logged(cosmos_client, sb_client, logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message) airlock_request_repo.return_value.get_airlock_request_by_id.side_effect = EntityDoesNotExist - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(service_bus_received_message_mock) @@ -173,12 +173,12 @@ async def test_updating_non_existent_airlock_request_error_is_logged(app, sb_cli @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_when_updating_and_state_store_exception_error_is_logged(app, logging_mock, airlock_request_repo, _): +@patch("api.dependencies.database.Database.get_db_client") +async def test_when_updating_and_state_store_exception_error_is_logged(cosmos_client, logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message) airlock_request_repo.return_value.get_airlock_request_by_id.side_effect = Exception - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(service_bus_received_message_mock) @@ -189,13 +189,13 @@ async def test_when_updating_and_state_store_exception_error_is_logged(app, logg @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.error') -@patch('fastapi.FastAPI') -async def test_when_updating_and_current_status_differs_from_status_in_state_store_error_is_logged(app, logging_mock, airlock_request_repo, _): +@patch("api.dependencies.database.Database.get_db_client") +async def test_when_updating_and_current_status_differs_from_status_in_state_store_error_is_logged(cosmos_client, logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message) expected_airlock_request = sample_airlock_request(AirlockRequestStatus.Draft) airlock_request_repo.return_value.get_airlock_request_by_id.return_value = expected_airlock_request - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(service_bus_received_message_mock) @@ -208,12 +208,12 @@ async def test_when_updating_and_current_status_differs_from_status_in_state_sto @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.exception') @patch('service_bus.airlock_request_status_update.ServiceBusClient') -@patch('fastapi.FastAPI') -async def test_when_updating_and_status_update_is_illegal_error_is_logged(app, sb_client, logging_mock, airlock_request_repo, _): +@patch("api.dependencies.database.Database.get_db_client") +async def test_when_updating_and_status_update_is_illegal_error_is_logged(cosmos_client, sb_client, logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message_with_invalid_status) airlock_request_repo.return_value.get_airlock_request_by_id.side_effect = HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - airlockStatusUpdater = AirlockStatusUpdater(app) + airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() complete_message = await airlockStatusUpdater.process_message(service_bus_received_message_mock) diff --git a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py index 71cfe2d843..b013914ba2 100644 --- a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py @@ -118,11 +118,11 @@ def create_sample_operation(resource_id, request_action): @pytest.mark.parametrize("payload", test_data) @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_receiving_bad_json_logs_error(app, logging_mock, payload): +@patch("api.dependencies.database.Database.get_db_client") +async def test_receiving_bad_json_logs_error(cosmos_client, logging_mock, payload): service_bus_received_message_mock = ServiceBusReceivedMessageMock(payload) - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() complete_message = await status_updater.process_message(service_bus_received_message_mock) # bad message data will fail. we don't mark complete=true since we want the message in the DLQ @@ -138,15 +138,15 @@ async def test_receiving_bad_json_logs_error(app, logging_mock, payload): @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_receiving_good_message(app, logging_mock, resource_repo, operation_repo, _, __): +@patch("api.dependencies.database.Database.get_db_client") +async def test_receiving_good_message(cosmos_client, logging_mock, resource_repo, operation_repo, _, __): expected_workspace = create_sample_workspace_object(test_sb_message["id"]) resource_repo.return_value.get_resource_dict_by_id.return_value = expected_workspace.dict() operation = create_sample_operation(test_sb_message["id"], RequestAction.Install) operation_repo.return_value.get_operation_by_id.return_value = operation - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(ServiceBusReceivedMessageMock(test_sb_message)) @@ -161,14 +161,14 @@ async def test_receiving_good_message(app, logging_mock, resource_repo, operatio @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_when_updating_non_existent_workspace_error_is_logged(app, logging_mock, resource_repo, operation_repo, _, __): +@patch("api.dependencies.database.Database.get_db_client") +async def test_when_updating_non_existent_workspace_error_is_logged(cosmos_client, logging_mock, resource_repo, operation_repo, _, __): resource_repo.return_value.get_resource_dict_by_id.side_effect = EntityDoesNotExist operation = create_sample_operation(test_sb_message["id"], RequestAction.Install) operation_repo.return_value.get_operation_by_id.return_value = operation - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(ServiceBusReceivedMessageMock(test_sb_message)) @@ -182,14 +182,14 @@ async def test_when_updating_non_existent_workspace_error_is_logged(app, logging @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('services.logging.logger.exception') -@patch('fastapi.FastAPI') -async def test_when_updating_and_state_store_exception(app, logging_mock, resource_repo, operation_repo, _, __): +@patch("api.dependencies.database.Database.get_db_client") +async def test_when_updating_and_state_store_exception(cosmos_client, logging_mock, resource_repo, operation_repo, _, __): resource_repo.return_value.get_resource_dict_by_id.side_effect = Exception operation = create_sample_operation(test_sb_message["id"], RequestAction.Install) operation_repo.return_value.get_operation_by_id.return_value = operation - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(ServiceBusReceivedMessageMock(test_sb_message)) @@ -202,8 +202,8 @@ async def test_when_updating_and_state_store_exception(app, logging_mock, resour @patch("service_bus.deployment_status_updater.get_timestamp", return_value=FAKE_UPDATE_TIMESTAMP) @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') -@patch('fastapi.FastAPI') -async def test_state_transitions_from_deployed_to_deleted(app, resource_repo, operations_repo_mock, _, __, ___): +@patch("api.dependencies.database.Database.get_db_client") +async def test_state_transitions_from_deployed_to_deleted(cosmos_client, resource_repo, operations_repo_mock, _, __, ___): updated_message = test_sb_message updated_message["status"] = Status.Deleted updated_message["message"] = "Has been deleted" @@ -222,7 +222,7 @@ async def test_state_transitions_from_deployed_to_deleted(app, resource_repo, op expected_operation.status = Status.Deleted expected_operation.message = updated_message["message"] - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(service_bus_received_message_mock) @@ -234,8 +234,8 @@ async def test_state_transitions_from_deployed_to_deleted(app, resource_repo, op @patch('service_bus.deployment_status_updater.ResourceTemplateRepository.create') @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') -@patch('fastapi.FastAPI') -async def test_outputs_are_added_to_resource_item(app, resource_repo, operations_repo, _, __): +@patch("api.dependencies.database.Database.get_db_client") +async def test_outputs_are_added_to_resource_item(cosmos_client, resource_repo, operations_repo, _, __): received_message = test_sb_message_with_outputs received_message["status"] = Status.Deployed service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -260,7 +260,7 @@ async def test_outputs_are_added_to_resource_item(app, resource_repo, operations operation = create_sample_operation(resource.id, RequestAction.UnInstall) operations_repo.return_value.get_operation_by_id.return_value = operation - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(service_bus_received_message_mock) @@ -272,8 +272,8 @@ async def test_outputs_are_added_to_resource_item(app, resource_repo, operations @patch('service_bus.deployment_status_updater.ResourceTemplateRepository.create') @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') -@patch('fastapi.FastAPI') -async def test_properties_dont_change_with_no_outputs(app, resource_repo, operations_repo, _, __): +@patch("api.dependencies.database.Database.get_db_client") +async def test_properties_dont_change_with_no_outputs(cosmos_client, resource_repo, operations_repo, _, __): received_message = test_sb_message received_message["status"] = Status.Deployed service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -287,7 +287,7 @@ async def test_properties_dont_change_with_no_outputs(app, resource_repo, operat expected_resource = resource - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(service_bus_received_message_mock) @@ -301,8 +301,8 @@ async def test_properties_dont_change_with_no_outputs(app, resource_repo, operat @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('service_bus.helpers.ServiceBusClient') -@patch('fastapi.FastAPI') -async def test_multi_step_operation_sends_next_step(app, sb_sender_client, resource_repo, operations_repo, update_resource_for_step, _, __, multi_step_operation, user_resource_multi, basic_shared_service): +@patch("api.dependencies.database.Database.get_db_client") +async def test_multi_step_operation_sends_next_step(cosmos_client, sb_sender_client, resource_repo, operations_repo, update_resource_for_step, _, __, multi_step_operation, user_resource_multi, basic_shared_service): received_message = test_sb_message_multi_step_1_complete received_message["status"] = Status.Updated service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -320,7 +320,7 @@ async def test_multi_step_operation_sends_next_step(app, sb_sender_client, resou operations_repo.return_value.get_operation_by_id.return_value = multi_step_operation update_resource_for_step.return_value = user_resource_multi - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(service_bus_received_message_mock) @@ -356,8 +356,8 @@ async def test_multi_step_operation_sends_next_step(app, sb_sender_client, resou @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('service_bus.helpers.ServiceBusClient') -@patch('fastapi.FastAPI') -async def test_multi_step_operation_ends_at_last_step(app, sb_sender_client, resource_repo, operations_repo, _, __, multi_step_operation, user_resource_multi, basic_shared_service): +@patch("api.dependencies.database.Database.get_db_client") +async def test_multi_step_operation_ends_at_last_step(cosmos_client, sb_sender_client, resource_repo, operations_repo, _, __, multi_step_operation, user_resource_multi, basic_shared_service): received_message = test_sb_message_multi_step_3_complete received_message["status"] = Status.Updated service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -384,7 +384,7 @@ async def test_multi_step_operation_ends_at_last_step(app, sb_sender_client, res operations_repo.return_value.get_operation_by_id.return_value = in_flight_op - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() await status_updater.init_repos() complete_message = await status_updater.process_message(service_bus_received_message_mock) assert complete_message is True @@ -401,13 +401,13 @@ async def test_multi_step_operation_ends_at_last_step(app, sb_sender_client, res sb_sender_client().get_queue_sender().send_messages.assert_not_called() -@patch('fastapi.FastAPI') -async def test_convert_outputs_to_dict(app): +@patch("api.dependencies.database.Database.get_db_client") +async def test_convert_outputs_to_dict(cosmos_client): # Test case 1: Empty list of outputs outputs_list = [] expected_result = {} - status_updater = DeploymentStatusUpdater(app) + status_updater = DeploymentStatusUpdater() assert status_updater.convert_outputs_to_dict(outputs_list) == expected_result # Test case 2: List of outputs with mixed types diff --git a/api_app/tests_ma/test_services/test_health_checker.py b/api_app/tests_ma/test_services/test_health_checker.py index 76587ef134..61e76b85c4 100644 --- a/api_app/tests_ma/test_services/test_health_checker.py +++ b/api_app/tests_ma/test_services/test_health_checker.py @@ -12,8 +12,8 @@ @patch("core.credentials.get_credential_async") -@patch("api.dependencies.database.get_store_key") -@patch("api.dependencies.database.CosmosClient") +@patch("api.dependencies.database.Database._get_store_key") +@patch("api.dependencies.database.Database.cosmos_client") async def test_get_state_store_status_responding(_, get_store_key_mock, get_credential_async) -> None: get_store_key_mock.return_value = None status, message = await health_checker.create_state_store_status(get_credential_async) @@ -23,8 +23,8 @@ async def test_get_state_store_status_responding(_, get_store_key_mock, get_cred @patch("core.credentials.get_credential_async") -@patch("api.dependencies.database.get_store_key") -@patch("api.dependencies.database.get_db_client") +@patch("api.dependencies.database.Database._get_store_key") +@patch("api.dependencies.database.Database.get_db_client") async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: get_credential_async.return_value = AsyncMock() get_store_key_mock.return_value = None @@ -37,8 +37,8 @@ async def test_get_state_store_status_not_responding(cosmos_client_mock, get_sto @patch("core.credentials.get_credential_async") -@patch("api.dependencies.database.get_store_key") -@patch("api.dependencies.database.get_db_client") +@patch("api.dependencies.database.Database._get_store_key") +@patch("api.dependencies.database.Database.get_db_client") async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: get_credential_async.return_value = AsyncMock() get_store_key_mock.return_value = None From 93729d368b713c75908d0a1232c9fd2d20b894e1 Mon Sep 17 00:00:00 2001 From: marrobi Date: Fri, 22 Dec 2023 11:31:04 +0000 Subject: [PATCH 15/32] fix linting --- api_app/_version.py | 2 +- api_app/api/dependencies/database.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index ddf3e33504..1317d7554a 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.17.13" +__version__ = "0.18.0" diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index 373be00df1..8a659b279b 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -2,7 +2,7 @@ from azure.cosmos.aio import CosmosClient from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient -from fastapi import Depends, FastAPI, HTTPException, Request, status +from fastapi import HTTPException, status from core.config import MANAGED_IDENTITY_CLIENT_ID, STATE_STORE_ENDPOINT, STATE_STORE_KEY, STATE_STORE_SSL_VERIFY, SUBSCRIPTION_ID, RESOURCE_MANAGER_ENDPOINT, CREDENTIAL_SCOPES, RESOURCE_GROUP_NAME, COSMOSDB_ACCOUNT_NAME from core.credentials import get_credential_async from db.errors import UnableToAccessDatabase @@ -13,6 +13,7 @@ class Singleton(type): _instances = {} + def __call__(cls, *args, **kwargs): if cls not in cls._instances: cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) @@ -81,12 +82,8 @@ async def get_db_client(self) -> CosmosClient: Database.cosmos_client = await self._connect_to_db() return self.cosmos_client - @classmethod - def get_repository(self, - repo_type: Type[BaseRepository], - ) -> Callable[[CosmosClient], BaseRepository]: - + def get_repository(self, repo_type: Type[BaseRepository]) -> Callable[[CosmosClient], BaseRepository]: async def _get_repo() -> BaseRepository: try: return await repo_type.create(self.cosmos_client) From 1fd7fa780750a003490f4fe5e2d7bb5114c58cb2 Mon Sep 17 00:00:00 2001 From: marrobi Date: Fri, 22 Dec 2023 14:59:04 +0000 Subject: [PATCH 16/32] Remove async with cosmos_client as this is closing the client on exit --- api_app/_version.py | 2 +- api_app/api/dependencies/database.py | 41 +++++++++++++++------------- api_app/api/routes/health.py | 2 +- api_app/services/health_checker.py | 7 ++--- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index 1317d7554a..b0d7306ae4 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.18.0" +__version__ = "0.18.2" diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index 8a659b279b..2bac410eb2 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -4,7 +4,7 @@ from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient from fastapi import HTTPException, status from core.config import MANAGED_IDENTITY_CLIENT_ID, STATE_STORE_ENDPOINT, STATE_STORE_KEY, STATE_STORE_SSL_VERIFY, SUBSCRIPTION_ID, RESOURCE_MANAGER_ENDPOINT, CREDENTIAL_SCOPES, RESOURCE_GROUP_NAME, COSMOSDB_ACCOUNT_NAME -from core.credentials import get_credential_async +from core.credentials import get_credential from db.errors import UnableToAccessDatabase from db.repositories.base import BaseRepository from resources import strings @@ -30,28 +30,31 @@ def __init__(self): async def _connect_to_db(self) -> CosmosClient: logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}") - async with get_credential_async() as credential: - if MANAGED_IDENTITY_CLIENT_ID: - logger.debug("Connecting with managed identity") + credential = get_credential() + if MANAGED_IDENTITY_CLIENT_ID: + logger.debug("Connecting with managed identity") + cosmos_client = CosmosClient( + url=STATE_STORE_ENDPOINT, + credential=credential + ) + else: + logger.debug("Connecting with key") + primary_master_key = await self._get_store_key(credential) + + if STATE_STORE_SSL_VERIFY: + logger.debug("Connecting with SSL verification") cosmos_client = CosmosClient( url=STATE_STORE_ENDPOINT, - credential=credential + credential=primary_master_key ) else: - logger.debug("Connecting with key") - primary_master_key = await self._get_store_key(credential) - - if STATE_STORE_SSL_VERIFY: - logger.debug("Connecting with SSL verification") - cosmos_client = CosmosClient( - url=STATE_STORE_ENDPOINT, credential=primary_master_key - ) - else: - logger.debug("Connecting without SSL verification") - # ignore TLS (setup is a pain) when using local Cosmos emulator. - cosmos_client = CosmosClient( - STATE_STORE_ENDPOINT, primary_master_key, connection_verify=False - ) + logger.debug("Connecting without SSL verification") + # ignore TLS (setup is a pain) when using local Cosmos emulator. + cosmos_client = CosmosClient( + url=STATE_STORE_ENDPOINT, + credential=primary_master_key, + connection_verify=False + ) logger.debug("Connection established") return cosmos_client diff --git a/api_app/api/routes/health.py b/api_app/api/routes/health.py index e0bab1f33b..535087b86a 100644 --- a/api_app/api/routes/health.py +++ b/api_app/api/routes/health.py @@ -16,7 +16,7 @@ async def health_check(request: Request) -> HealthCheck: # calling this endpoint frequently may result in API throttling. async with credentials.get_credential_async() as credential: cosmos, sb, rp = await asyncio.gather( - create_state_store_status(request), + create_state_store_status(), create_service_bus_status(credential), create_resource_processor_status(credential) ) diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index 6314fa9a79..6bafcbd499 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -12,14 +12,13 @@ from services.logging import logger -async def create_state_store_status(request) -> Tuple[StatusEnum, str]: +async def create_state_store_status() -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" try: cosmos_client = await Database().get_db_client() - async with cosmos_client: - list_databases_response = cosmos_client.list_databases() - [database async for database in list_databases_response] + list_databases_response = cosmos_client.list_databases() + [database async for database in list_databases_response] except exceptions.ServiceRequestError: status = StatusEnum.not_ok message = strings.STATE_STORE_ENDPOINT_NOT_RESPONDING From 0a968184d02df9dff2e4d40a5a236f7d08edb0e1 Mon Sep 17 00:00:00 2001 From: marrobi Date: Fri, 22 Dec 2023 15:21:02 +0000 Subject: [PATCH 17/32] Create method to get creds async without context manager --- api_app/_version.py | 2 +- api_app/api/dependencies/database.py | 4 +- api_app/api/routes/health.py | 2 +- api_app/core/credentials.py | 23 ++++++- api_app/event_grid/helpers.py | 2 +- .../airlock_request_status_update.py | 2 +- .../service_bus/deployment_status_updater.py | 2 +- api_app/service_bus/helpers.py | 2 +- api_app/tests_ma/test_db/test_events.py | 8 +-- .../test_services/test_health_checker.py | 65 +++++++++---------- 10 files changed, 64 insertions(+), 48 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index b0d7306ae4..30e65e30d0 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.18.2" +__version__ = "0.18.3" diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index 2bac410eb2..ab0486dca1 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -4,7 +4,7 @@ from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient from fastapi import HTTPException, status from core.config import MANAGED_IDENTITY_CLIENT_ID, STATE_STORE_ENDPOINT, STATE_STORE_KEY, STATE_STORE_SSL_VERIFY, SUBSCRIPTION_ID, RESOURCE_MANAGER_ENDPOINT, CREDENTIAL_SCOPES, RESOURCE_GROUP_NAME, COSMOSDB_ACCOUNT_NAME -from core.credentials import get_credential +from core.credentials import get_credential_async from db.errors import UnableToAccessDatabase from db.repositories.base import BaseRepository from resources import strings @@ -30,7 +30,7 @@ def __init__(self): async def _connect_to_db(self) -> CosmosClient: logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}") - credential = get_credential() + credential = await get_credential_async() if MANAGED_IDENTITY_CLIENT_ID: logger.debug("Connecting with managed identity") cosmos_client = CosmosClient( diff --git a/api_app/api/routes/health.py b/api_app/api/routes/health.py index 535087b86a..9d0ef42c12 100644 --- a/api_app/api/routes/health.py +++ b/api_app/api/routes/health.py @@ -14,7 +14,7 @@ async def health_check(request: Request) -> HealthCheck: # The health endpoint checks the status of key components of the system. # Note that Resource Processor checks incur Azure management calls, so # calling this endpoint frequently may result in API throttling. - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_cm() as credential: cosmos, sb, rp = await asyncio.gather( create_state_store_status(), create_service_bus_status(credential), diff --git a/api_app/core/credentials.py b/api_app/core/credentials.py index 8780248a63..331a44423c 100644 --- a/api_app/core/credentials.py +++ b/api_app/core/credentials.py @@ -31,8 +31,29 @@ def get_credential() -> TokenCredential: ) -@asynccontextmanager async def get_credential_async() -> TokenCredential: + """ + Context manager which yields the default credentials. + """ + managed_identity = config.MANAGED_IDENTITY_CLIENT_ID + credential = ( + ChainedTokenCredentialASync( + ManagedIdentityCredentialASync(client_id=managed_identity) + ) + if managed_identity + else DefaultAzureCredentialASync(authority=urlparse(config.AAD_AUTHORITY_URL).netloc, + exclude_shared_token_cache_credential=True, + exclude_workload_identity_credential=True, + exclude_developer_cli_credential=True, + exclude_managed_identity_credential=True, + exclude_powershell_credential=True + ) + ) + return credential + + +@asynccontextmanager +async def get_credential_async_cm() -> TokenCredential: """ Context manager which yields the default credentials. """ diff --git a/api_app/event_grid/helpers.py b/api_app/event_grid/helpers.py index ed96ec695c..debfcfdfc3 100644 --- a/api_app/event_grid/helpers.py +++ b/api_app/event_grid/helpers.py @@ -4,7 +4,7 @@ async def publish_event(event: EventGridEvent, topic_endpoint: str): - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_cm() as credential: client = EventGridPublisherClient(topic_endpoint, credential) async with client: await client.send([event]) diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index babb302702..7255f3e136 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -32,7 +32,7 @@ async def receive_messages(self): with tracer.start_as_current_span("airlock_receive_messages"): while True: try: - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_cm() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) receiver = service_bus_client.get_queue_receiver(queue_name=config.SERVICE_BUS_STEP_RESULT_QUEUE) logger.info(f"Looking for new messages on {config.SERVICE_BUS_STEP_RESULT_QUEUE} queue...") diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 255b25d1c3..26ebfe9309 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -41,7 +41,7 @@ async def receive_messages(self): with tracer.start_as_current_span("deployment_status_receive_messages"): while True: try: - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_cm() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) logger.info(f"Looking for new messages on {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue...") diff --git a/api_app/service_bus/helpers.py b/api_app/service_bus/helpers.py index 55ae5c1b20..79f11b687b 100644 --- a/api_app/service_bus/helpers.py +++ b/api_app/service_bus/helpers.py @@ -24,7 +24,7 @@ async def _send_message(message: ServiceBusMessage, queue: str): :param queue: The Service Bus queue to send the message to. :type queue: str """ - async with credentials.get_credential_async() as credential: + async with credentials.get_credential_async_cm() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) async with service_bus_client: diff --git a/api_app/tests_ma/test_db/test_events.py b/api_app/tests_ma/test_db/test_events.py index 7b93a16a82..016067b105 100644 --- a/api_app/tests_ma/test_db/test_events.py +++ b/api_app/tests_ma/test_db/test_events.py @@ -8,8 +8,8 @@ @patch("db.events.get_credential") @patch("db.events.CosmosDBManagementClient") -async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_mock): - get_credential_async_mock.return_value = AsyncMock() +async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_cm_mock): + get_credential_async_cm_mock.return_value = AsyncMock() cosmos_db_mgmt_client_mock.return_value = MagicMock() result = await events.bootstrap_database() @@ -19,8 +19,8 @@ async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_creden @patch("db.events.get_credential") @patch("db.events.CosmosDBManagementClient") -async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_mock): - get_credential_async_mock.return_value = AsyncMock() +async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_cm_mock): + get_credential_async_cm_mock.return_value = AsyncMock() cosmos_db_mgmt_client_mock.side_effect = AzureError("some error") result = await events.bootstrap_database() diff --git a/api_app/tests_ma/test_services/test_health_checker.py b/api_app/tests_ma/test_services/test_health_checker.py index 61e76b85c4..d918b5c87e 100644 --- a/api_app/tests_ma/test_services/test_health_checker.py +++ b/api_app/tests_ma/test_services/test_health_checker.py @@ -11,84 +11,79 @@ pytestmark = pytest.mark.asyncio -@patch("core.credentials.get_credential_async") @patch("api.dependencies.database.Database._get_store_key") @patch("api.dependencies.database.Database.cosmos_client") -async def test_get_state_store_status_responding(_, get_store_key_mock, get_credential_async) -> None: +async def test_get_state_store_status_responding(_, get_store_key_mock) -> None: get_store_key_mock.return_value = None - status, message = await health_checker.create_state_store_status(get_credential_async) + status, message = await health_checker.create_state_store_status() assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async") @patch("api.dependencies.database.Database._get_store_key") @patch("api.dependencies.database.Database.get_db_client") -async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock) -> None: get_store_key_mock.return_value = None cosmos_client_mock.return_value = None cosmos_client_mock.side_effect = ServiceRequestError(message="some message") - status, message = await health_checker.create_state_store_status(get_credential_async) + status, message = await health_checker.create_state_store_status() assert status == StatusEnum.not_ok assert message == strings.STATE_STORE_ENDPOINT_NOT_RESPONDING -@patch("core.credentials.get_credential_async") @patch("api.dependencies.database.Database._get_store_key") @patch("api.dependencies.database.Database.get_db_client") -async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock) -> None: get_store_key_mock.return_value = None cosmos_client_mock.return_value = None cosmos_client_mock.side_effect = Exception() - status, message = await health_checker.create_state_store_status(get_credential_async) + status, message = await health_checker.create_state_store_status() assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_responding(service_bus_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_service_bus_status_responding(service_bus_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() service_bus_client_mock().get_queue_receiver.__aenter__.return_value = AsyncMock() - status, message = await health_checker.create_service_bus_status(get_credential_async) + status, message = await health_checker.create_service_bus_status(get_credential_async_cm) assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_not_responding(service_bus_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_service_bus_status_not_responding(service_bus_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() service_bus_client_mock.return_value = None service_bus_client_mock.side_effect = ServiceBusConnectionError(message="some message") - status, message = await health_checker.create_service_bus_status(get_credential_async) + status, message = await health_checker.create_service_bus_status(get_credential_async_cm) assert status == StatusEnum.not_ok assert message == strings.SERVICE_BUS_NOT_RESPONDING -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_other_exception(service_bus_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_service_bus_status_other_exception(service_bus_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() service_bus_client_mock.return_value = None service_bus_client_mock.side_effect = Exception() - status, message = await health_checker.create_service_bus_status(get_credential_async) + status, message = await health_checker.create_service_bus_status(get_credential_async_cm) assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ComputeManagementClient") -async def test_get_resource_processor_status_healthy(resource_processor_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_resource_processor_status_healthy(resource_processor_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() resource_processor_client_mock().virtual_machine_scale_set_vms.return_value = AsyncMock() vm_mock = MagicMock() vm_mock.instance_id = 'mocked_id' @@ -100,16 +95,16 @@ async def test_get_resource_processor_status_healthy(resource_processor_client_m awaited_mock.set_result(instance_view_mock) resource_processor_client_mock().virtual_machine_scale_set_vms.get_instance_view.return_value = awaited_mock - status, message = await health_checker.create_resource_processor_status(get_credential_async) + status, message = await health_checker.create_resource_processor_status(get_credential_async_cm) assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ComputeManagementClient", return_value=MagicMock()) -async def test_get_resource_processor_status_not_healthy(resource_processor_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_resource_processor_status_not_healthy(resource_processor_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() resource_processor_client_mock().virtual_machine_scale_set_vms.return_value = AsyncMock() vm_mock = MagicMock() @@ -122,19 +117,19 @@ async def test_get_resource_processor_status_not_healthy(resource_processor_clie awaited_mock.set_result(instance_view_mock) resource_processor_client_mock().virtual_machine_scale_set_vms.get_instance_view.return_value = awaited_mock - status, message = await health_checker.create_resource_processor_status(get_credential_async) + status, message = await health_checker.create_resource_processor_status(get_credential_async_cm) assert status == StatusEnum.not_ok assert message == strings.RESOURCE_PROCESSOR_GENERAL_ERROR_MESSAGE -@patch("core.credentials.get_credential_async") +@patch("core.credentials.get_credential_async_cm") @patch("services.health_checker.ComputeManagementClient") -async def test_get_resource_processor_status_other_exception(resource_processor_client_mock, get_credential_async) -> None: - get_credential_async.return_value = AsyncMock() +async def test_get_resource_processor_status_other_exception(resource_processor_client_mock, get_credential_async_cm) -> None: + get_credential_async_cm.return_value = AsyncMock() resource_processor_client_mock.return_value = None resource_processor_client_mock.side_effect = Exception() - status, message = await health_checker.create_resource_processor_status(get_credential_async) + status, message = await health_checker.create_resource_processor_status(get_credential_async_cm) assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR From 6ab2fba111f89eddf8db0d020dd805e8fe6ba4a5 Mon Sep 17 00:00:00 2001 From: marrobi Date: Tue, 2 Jan 2024 12:36:21 +0000 Subject: [PATCH 18/32] Factor out cosmos client from everywhere but base repo and health checks --- api_app/_version.py | 2 +- api_app/api/dependencies/database.py | 4 ++-- api_app/db/migrations/airlock.py | 6 ++---- api_app/db/migrations/resources.py | 6 ++---- api_app/db/migrations/shared_services.py | 6 ++---- api_app/db/migrations/workspaces.py | 6 ++---- api_app/db/repositories/airlock_requests.py | 4 ++-- api_app/db/repositories/base.py | 5 ++--- api_app/db/repositories/operations.py | 5 ++--- api_app/db/repositories/resource_templates.py | 5 ++--- api_app/db/repositories/resources.py | 7 +++---- api_app/db/repositories/resources_history.py | 5 ++--- api_app/db/repositories/shared_services.py | 5 ++--- api_app/db/repositories/user_resources.py | 5 ++--- api_app/db/repositories/workspace_services.py | 5 ++--- api_app/db/repositories/workspaces.py | 5 ++--- api_app/main.py | 2 -- .../service_bus/airlock_request_status_update.py | 5 ++--- api_app/service_bus/deployment_status_updater.py | 9 ++++----- api_app/services/aad_authentication.py | 3 +-- .../test_api/test_routes/test_resource_helpers.py | 15 ++++++--------- .../test_api/test_routes/test_workspaces.py | 2 +- .../test_migrations/test_workspace_migration.py | 5 ++--- .../test_airlock_request_repository.py | 5 ++--- .../test_repositories/test_base_repository.py | 2 +- .../test_operation_repository.py | 15 ++++++--------- .../test_resource_history_repository.py | 5 ++--- .../test_repositories/test_resource_repository.py | 10 ++++------ .../test_resource_templates_repository.py | 5 ++--- .../test_shared_service_repository.py | 10 ++++------ .../test_shared_service_templates_repository.py | 5 ++--- .../test_user_resource_repository.py | 5 ++--- .../test_user_resource_templates_repository.py | 5 ++--- .../test_workpaces_repository.py | 10 ++++------ .../test_workpaces_service_repository.py | 10 ++++------ api_app/tests_ma/test_services/test_airlock.py | 2 +- 36 files changed, 84 insertions(+), 127 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index 30e65e30d0..e61f7a55d6 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.18.3" +__version__ = "0.18.4" diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index ab0486dca1..9c08caa8aa 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -86,10 +86,10 @@ async def get_db_client(self) -> CosmosClient: return self.cosmos_client @classmethod - def get_repository(self, repo_type: Type[BaseRepository]) -> Callable[[CosmosClient], BaseRepository]: + def get_repository(self, repo_type: Type[BaseRepository]) -> Callable: async def _get_repo() -> BaseRepository: try: - return await repo_type.create(self.cosmos_client) + return await repo_type.create() except UnableToAccessDatabase: logger.exception(strings.STATE_STORE_ENDPOINT_NOT_RESPONDING) raise HTTPException( diff --git a/api_app/db/migrations/airlock.py b/api_app/db/migrations/airlock.py index 299441fb63..9c834aba07 100644 --- a/api_app/db/migrations/airlock.py +++ b/api_app/db/migrations/airlock.py @@ -1,15 +1,13 @@ -from azure.cosmos.aio import CosmosClient from resources import strings from db.repositories.airlock_requests import AirlockRequestRepository class AirlockMigration(AirlockRequestRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = AirlockMigration() - resource_repo = await super().create(client) + resource_repo = await super().create() cls._container = resource_repo._container - cls._client = resource_repo._client return cls async def add_created_by_and_rename_in_history(self) -> int: diff --git a/api_app/db/migrations/resources.py b/api_app/db/migrations/resources.py index c0cdcf25e7..4c99bff4cf 100644 --- a/api_app/db/migrations/resources.py +++ b/api_app/db/migrations/resources.py @@ -1,5 +1,4 @@ import uuid -from azure.cosmos.aio import CosmosClient from db.repositories.operations import OperationRepository from db.repositories.resources import ResourceRepository from db.repositories.resources_history import ResourceHistoryRepository @@ -7,11 +6,10 @@ class ResourceMigration(ResourceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = ResourceMigration() - resource_repo = await super().create(client) + resource_repo = await super().create() cls._container = resource_repo._container - cls._client = resource_repo._client return cls async def add_deployment_status_field(self, operations_repository: OperationRepository) -> int: diff --git a/api_app/db/migrations/shared_services.py b/api_app/db/migrations/shared_services.py index 575ac74bb2..621991314a 100644 --- a/api_app/db/migrations/shared_services.py +++ b/api_app/db/migrations/shared_services.py @@ -1,4 +1,3 @@ -from azure.cosmos.aio import CosmosClient import semantic_version from db.repositories.shared_services import SharedServiceRepository @@ -8,11 +7,10 @@ class SharedServiceMigration(SharedServiceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = SharedServiceMigration() - resource_repo = await super().create(client) + resource_repo = await super().create() cls._container = resource_repo._container - cls._client = resource_repo._client return cls async def deleteDuplicatedSharedServices(self) -> bool: diff --git a/api_app/db/migrations/workspaces.py b/api_app/db/migrations/workspaces.py index 79209ec30a..e8a422080b 100644 --- a/api_app/db/migrations/workspaces.py +++ b/api_app/db/migrations/workspaces.py @@ -1,4 +1,3 @@ -from azure.cosmos.aio import CosmosClient import semantic_version from db.repositories.workspaces import WorkspaceRepository @@ -7,11 +6,10 @@ class WorkspaceMigration(WorkspaceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = WorkspaceMigration() - resource_repo = await super().create(client) + resource_repo = await super().create() cls._container = resource_repo._container - cls._client = resource_repo._client return cls async def moveAuthInformationToProperties(self) -> bool: diff --git a/api_app/db/repositories/airlock_requests.py b/api_app/db/repositories/airlock_requests.py index ba683723d6..0ec78f0452 100644 --- a/api_app/db/repositories/airlock_requests.py +++ b/api_app/db/repositories/airlock_requests.py @@ -21,9 +21,9 @@ class AirlockRequestRepository(BaseRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = AirlockRequestRepository() - await super().create(client, config.STATE_STORE_AIRLOCK_REQUESTS_CONTAINER) + await super().create(config.STATE_STORE_AIRLOCK_REQUESTS_CONTAINER) return cls @staticmethod diff --git a/api_app/db/repositories/base.py b/api_app/db/repositories/base.py index 15ea9bf0fc..cc1bc78385 100644 --- a/api_app/db/repositories/base.py +++ b/api_app/db/repositories/base.py @@ -1,5 +1,5 @@ from typing import Optional -from azure.cosmos.aio import CosmosClient, ContainerProxy +from azure.cosmos.aio import ContainerProxy from azure.core import MatchConditions from pydantic import BaseModel @@ -9,8 +9,7 @@ class BaseRepository: @classmethod - async def create(cls, client: CosmosClient, container_name: Optional[str] = None): - cls._client: CosmosClient = client + async def create(cls, container_name: Optional[str] = None): cls._container: ContainerProxy = await cls._get_container(container_name) return cls diff --git a/api_app/db/repositories/operations.py b/api_app/db/repositories/operations.py index 394e6713e8..8489d231d6 100644 --- a/api_app/db/repositories/operations.py +++ b/api_app/db/repositories/operations.py @@ -2,7 +2,6 @@ import uuid from typing import List -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from db.repositories.resource_templates import ResourceTemplateRepository from resources import strings @@ -19,9 +18,9 @@ class OperationRepository(BaseRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = OperationRepository() - await super().create(client, config.STATE_STORE_OPERATIONS_CONTAINER) + await super().create(config.STATE_STORE_OPERATIONS_CONTAINER) return cls @staticmethod diff --git a/api_app/db/repositories/resource_templates.py b/api_app/db/repositories/resource_templates.py index 471269d100..288b096883 100644 --- a/api_app/db/repositories/resource_templates.py +++ b/api_app/db/repositories/resource_templates.py @@ -1,7 +1,6 @@ import uuid from typing import List, Optional, Union -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from core import config @@ -16,9 +15,9 @@ class ResourceTemplateRepository(BaseRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = ResourceTemplateRepository() - await super().create(client, config.STATE_STORE_RESOURCE_TEMPLATES_CONTAINER) + await super().create(config.STATE_STORE_RESOURCE_TEMPLATES_CONTAINER) return cls @staticmethod diff --git a/api_app/db/repositories/resources.py b/api_app/db/repositories/resources.py index 6cbc464080..a1740dd167 100644 --- a/api_app/db/repositories/resources.py +++ b/api_app/db/repositories/resources.py @@ -3,7 +3,6 @@ from datetime import datetime from typing import Optional, Tuple, List -from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosResourceNotFoundError from core import config from db.errors import VersionDowngradeDenied, EntityDoesNotExist, MajorVersionUpdateDenied, TargetTemplateVersionDoesNotExist, UserNotAuthorizedToUseTemplate @@ -25,9 +24,9 @@ class ResourceRepository(BaseRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = ResourceRepository() - await super().create(client, config.STATE_STORE_RESOURCES_CONTAINER) + await super().create(config.STATE_STORE_RESOURCES_CONTAINER) return cls @staticmethod @@ -46,7 +45,7 @@ def _validate_resource_parameters(resource_input, resource_template): validate(instance=resource_input["properties"], schema=resource_template) async def _get_enriched_template(self, template_name: str, resource_type: ResourceType, parent_template_name: str = "") -> dict: - template_repo = await ResourceTemplateRepository.create(self._client) + template_repo = await ResourceTemplateRepository.create() template = await template_repo.get_current_template(template_name, resource_type, parent_template_name) return template_repo.enrich_template(template) diff --git a/api_app/db/repositories/resources_history.py b/api_app/db/repositories/resources_history.py index 6005619b39..2ccfc061e5 100644 --- a/api_app/db/repositories/resources_history.py +++ b/api_app/db/repositories/resources_history.py @@ -1,6 +1,5 @@ from typing import List import uuid -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from db.errors import EntityDoesNotExist @@ -12,9 +11,9 @@ class ResourceHistoryRepository(BaseRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = ResourceHistoryRepository() - await super().create(client, config.STATE_STORE_RESOURCES_HISTORY_CONTAINER) + await super().create(config.STATE_STORE_RESOURCES_HISTORY_CONTAINER) return cls @staticmethod diff --git a/api_app/db/repositories/shared_services.py b/api_app/db/repositories/shared_services.py index c7cc1e988a..af83ab3116 100644 --- a/api_app/db/repositories/shared_services.py +++ b/api_app/db/repositories/shared_services.py @@ -2,7 +2,6 @@ from typing import List, Tuple import uuid -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from models.domain.resource_template import ResourceTemplate from models.domain.authentication import User @@ -18,9 +17,9 @@ class SharedServiceRepository(ResourceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = SharedServiceRepository() - await super().create(client) + await super().create() return cls @staticmethod diff --git a/api_app/db/repositories/user_resources.py b/api_app/db/repositories/user_resources.py index 4cffe296a8..dd093e2419 100644 --- a/api_app/db/repositories/user_resources.py +++ b/api_app/db/repositories/user_resources.py @@ -1,7 +1,6 @@ import uuid from typing import List, Tuple -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from db.repositories.resources_history import ResourceHistoryRepository from models.domain.resource_template import ResourceTemplate @@ -18,9 +17,9 @@ class UserResourceRepository(ResourceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = UserResourceRepository() - await super().create(client) + await super().create() return cls @staticmethod diff --git a/api_app/db/repositories/workspace_services.py b/api_app/db/repositories/workspace_services.py index 48523dbcac..5f614aaa94 100644 --- a/api_app/db/repositories/workspace_services.py +++ b/api_app/db/repositories/workspace_services.py @@ -1,7 +1,6 @@ import uuid from typing import List, Tuple -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from db.repositories.resources_history import ResourceHistoryRepository from models.domain.resource_template import ResourceTemplate @@ -19,9 +18,9 @@ class WorkspaceServiceRepository(ResourceRepository): @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = WorkspaceServiceRepository() - await super().create(client) + await super().create() return cls @staticmethod diff --git a/api_app/db/repositories/workspaces.py b/api_app/db/repositories/workspaces.py index 23af53109b..8065d48be0 100644 --- a/api_app/db/repositories/workspaces.py +++ b/api_app/db/repositories/workspaces.py @@ -1,7 +1,6 @@ import uuid from typing import List, Tuple -from azure.cosmos.aio import CosmosClient from pydantic import parse_obj_as from db.repositories.resources_history import ResourceHistoryRepository from models.domain.resource_template import ResourceTemplate @@ -28,9 +27,9 @@ class WorkspaceRepository(ResourceRepository): predefined_address_spaces = {"small": 24, "medium": 22, "large": 16} @classmethod - async def create(cls, client: CosmosClient): + async def create(cls): cls = WorkspaceRepository() - await super().create(client) + await super().create() return cls @staticmethod diff --git a/api_app/main.py b/api_app/main.py index 8d88e9f6e8..0bdc769141 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -24,8 +24,6 @@ @asynccontextmanager async def lifespan(app: FastAPI): - app.state.cosmos_client = None - while not await bootstrap_database(): await asyncio.sleep(5) logger.warning("Database connection could not be established") diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index 7255f3e136..ad81402ebe 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -24,9 +24,8 @@ def __init__(self): pass async def init_repos(self): - db_client = await Database().get_db_client() - self.airlock_request_repo = await AirlockRequestRepository.create(db_client) - self.workspace_repo = await WorkspaceRepository.create(db_client) + self.airlock_request_repo = await AirlockRequestRepository.create() + self.workspace_repo = await WorkspaceRepository.create() async def receive_messages(self): with tracer.start_as_current_span("airlock_receive_messages"): diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 26ebfe9309..de2a563e86 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -28,11 +28,10 @@ def __init__(self): pass async def init_repos(self): - db_client = await Database().get_db_client() - self.operations_repo = await OperationRepository.create(db_client) - self.resource_repo = await ResourceRepository.create(db_client) - self.resource_template_repo = await ResourceTemplateRepository.create(db_client) - self.resource_history_repo = await ResourceHistoryRepository.create(db_client) + self.operations_repo = await OperationRepository.create() + self.resource_repo = await ResourceRepository.create() + self.resource_template_repo = await ResourceTemplateRepository.create() + self.resource_history_repo = await ResourceHistoryRepository.create() def run(self, *args, **kwargs): asyncio.run(self.receive_messages()) diff --git a/api_app/services/aad_authentication.py b/api_app/services/aad_authentication.py index 8842a0c607..bc39c48c87 100644 --- a/api_app/services/aad_authentication.py +++ b/api_app/services/aad_authentication.py @@ -121,8 +121,7 @@ async def _fetch_ws_app_reg_id_from_ws_id(request: Request) -> str: try: workspace_id = request.path_params['workspace_id'] - db_client = Database().get_db_client() - ws_repo = await WorkspaceRepository.create(db_client) + ws_repo = await WorkspaceRepository.create() workspace = await ws_repo.get_workspace_by_id(workspace_id) ws_app_reg_id = "" diff --git a/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py b/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py index 092662930e..34bf8659a1 100644 --- a/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py +++ b/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py @@ -30,23 +30,20 @@ @pytest_asyncio.fixture async def resource_repo() -> ResourceRepository: with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): - with patch("azure.cosmos.CosmosClient") as cosmos_client_mock: - resource_repo_mock = await ResourceRepository.create(cosmos_client_mock) - yield resource_repo_mock + resource_repo_mock = await ResourceRepository.create() + yield resource_repo_mock @pytest_asyncio.fixture async def operations_repo() -> OperationRepository: - with patch("azure.cosmos.CosmosClient") as cosmos_client_mock: - operation_repo_mock = await OperationRepository.create(cosmos_client_mock) - yield operation_repo_mock + operation_repo_mock = await OperationRepository.create() + yield operation_repo_mock @pytest_asyncio.fixture async def resource_history_repo() -> ResourceHistoryRepository: - with patch("azure.cosmos.CosmosClient") as cosmos_client_mock: - resource_history_repo_mock = await ResourceHistoryRepository.create(cosmos_client_mock) - yield resource_history_repo_mock + resource_history_repo_mock = await ResourceHistoryRepository.create() + yield resource_history_repo_mock def sample_resource(workspace_id=WORKSPACE_ID): diff --git a/api_app/tests_ma/test_api/test_routes/test_workspaces.py b/api_app/tests_ma/test_api/test_routes/test_workspaces.py index feed3f2bc2..9b7ae7c18c 100644 --- a/api_app/tests_ma/test_api/test_routes/test_workspaces.py +++ b/api_app/tests_ma/test_api/test_routes/test_workspaces.py @@ -683,7 +683,7 @@ async def test_delete_workspace_returns_400_if_associated_workspace_services_are @ patch('api.routes.resource_helpers.send_resource_request_message', return_value=sample_resource_operation(resource_id=WORKSPACE_ID, operation_id=OPERATION_ID)) async def test_delete_workspace_sends_a_request_message_to_uninstall_the_workspace(self, send_request_message_mock, cosmos_client_mock, __, get_workspace_mock, get_repository_mock, resource_template_repo, ___, disabled_workspace, app, client, basic_resource_template): get_workspace_mock.return_value = disabled_workspace - get_repository_mock.side_effects = [await WorkspaceRepository.create(cosmos_client_mock), await WorkspaceServiceRepository.create(cosmos_client_mock)] + get_repository_mock.side_effects = [await WorkspaceRepository.create(), await WorkspaceServiceRepository.create()] resource_template_repo.return_value = basic_resource_template await client.delete(app.url_path_for(strings.API_DELETE_WORKSPACE, workspace_id=WORKSPACE_ID)) diff --git a/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py b/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py index c15310f6d3..5c4bfaaa7d 100644 --- a/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py +++ b/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py @@ -12,9 +12,8 @@ @pytest_asyncio.fixture async def workspace_migrator(): with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - workspace_migrator = await WorkspaceMigration.create(cosmos_client_mock) - yield workspace_migrator + workspace_migrator = await WorkspaceMigration.create() + yield workspace_migrator def get_sample_old_workspace(workspace_id: str = "7ab18f7e-ee8f-4202-8d46-747818ec76f4", spec_workspace_id: str = "0001") -> dict: diff --git a/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py b/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py index ffff8a3c18..0b91bb59f4 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py @@ -47,9 +47,8 @@ @pytest_asyncio.fixture async def airlock_request_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - airlock_request_repo_mock = await AirlockRequestRepository.create(cosmos_client_mock) - yield airlock_request_repo_mock + airlock_request_repo_mock = await AirlockRequestRepository.create() + yield airlock_request_repo_mock @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_base_repository.py b/api_app/tests_ma/test_db/test_repositories/test_base_repository.py index 3298b5049e..4d8c859edb 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_base_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_base_repository.py @@ -12,4 +12,4 @@ async def test_instantiating_a_repo_raises_unable_to_access_database_if_database with patch("azure.cosmos.CosmosClient") as cosmos_client_mock: cosmos_client_mock.get_database_client = MagicMock(side_effect=Exception) with pytest.raises(UnableToAccessDatabase): - await BaseRepository.create(cosmos_client_mock, "container") + await BaseRepository.create("container") diff --git a/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py b/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py index 09bc96cd70..43cdf3c4f1 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py @@ -19,25 +19,22 @@ @pytest_asyncio.fixture async def operations_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - operations_repo = await OperationRepository.create(cosmos_client_mock) - yield operations_repo + operations_repo = await OperationRepository.create() + yield operations_repo @pytest_asyncio.fixture async def resource_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_repo = await ResourceRepository.create(cosmos_client_mock) - yield resource_repo + resource_repo = await ResourceRepository.create() + yield resource_repo @pytest_asyncio.fixture async def resource_template_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_template_repo = await ResourceTemplateRepository.create(cosmos_client_mock) - yield resource_template_repo + resource_template_repo = await ResourceTemplateRepository.create() + yield resource_template_repo @patch('uuid.uuid4', side_effect=["random-uuid-1", "random-uuid-2", "random-uuid-3"]) diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py index aea50b9690..b4077e6a0b 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py @@ -17,9 +17,8 @@ @pytest_asyncio.fixture async def resource_history_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_history_repo = await ResourceHistoryRepository.create(cosmos_client_mock) - yield resource_history_repo + resource_history_repo = await ResourceHistoryRepository.create() + yield resource_history_repo @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py index d0dec25234..62f9282112 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py @@ -27,17 +27,15 @@ @pytest_asyncio.fixture async def resource_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_repo = await ResourceRepository.create(cosmos_client_mock) - yield resource_repo + resource_repo = await ResourceRepository.create() + yield resource_repo @pytest_asyncio.fixture async def resource_history_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_history_repo = await ResourceHistoryRepository.create(cosmos_client_mock) - yield resource_history_repo + resource_history_repo = await ResourceHistoryRepository.create() + yield resource_history_repo @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py index f84429b7c8..831f1fabb7 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py @@ -15,9 +15,8 @@ @pytest_asyncio.fixture async def resource_template_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_template_repo = await ResourceTemplateRepository.create(cosmos_client_mock) - yield resource_template_repo + resource_template_repo = await ResourceTemplateRepository.create() + yield resource_template_repo def sample_resource_template_as_dict(name: str, version: str = "1.0", resource_type: ResourceType = ResourceType.Workspace) -> ResourceTemplate: diff --git a/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py b/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py index 71da756dfd..f82c9ea710 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py @@ -18,17 +18,15 @@ @pytest_asyncio.fixture async def shared_service_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - shared_service_repo = await SharedServiceRepository.create(cosmos_client_mock) - yield shared_service_repo + shared_service_repo = await SharedServiceRepository.create() + yield shared_service_repo @pytest_asyncio.fixture async def operations_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - operations_repo = await OperationRepository.create(cosmos_client_mock) - yield operations_repo + operations_repo = await OperationRepository.create() + yield operations_repo @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py index 936ebe50d7..8025936185 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py @@ -12,9 +12,8 @@ @pytest_asyncio.fixture async def resource_template_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_template_repo = await ResourceTemplateRepository.create(cosmos_client_mock) - yield resource_template_repo + resource_template_repo = await ResourceTemplateRepository.create() + yield resource_template_repo # Because shared service templates repository uses generic ResourceTemplate repository, most test cases are already covered diff --git a/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py b/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py index 7a2adbebde..e6854c8b62 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py @@ -27,9 +27,8 @@ def basic_user_resource_request(): @pytest_asyncio.fixture async def user_resource_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - user_resource_repo = await UserResourceRepository.create(cosmos_client_mock) - yield user_resource_repo + user_resource_repo = await UserResourceRepository.create() + yield user_resource_repo @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py index 0cb5cb56fd..fc1ed642a5 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py @@ -14,9 +14,8 @@ @pytest_asyncio.fixture async def resource_template_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - resource_template_repo = await ResourceTemplateRepository.create(cosmos_client_mock) - yield resource_template_repo + resource_template_repo = await ResourceTemplateRepository.create() + yield resource_template_repo def sample_user_resource_template_as_dict(name: str, version: str = "1.0") -> dict: diff --git a/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py b/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py index 79b9a5176a..6aa551aaa8 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py @@ -21,17 +21,15 @@ def basic_workspace_request(): @pytest_asyncio.fixture async def workspace_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - workspace_repo = await WorkspaceRepository.create(cosmos_client_mock) - yield workspace_repo + workspace_repo = await WorkspaceRepository.create() + yield workspace_repo @pytest_asyncio.fixture async def operations_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - operations_repo = await OperationRepository.create(cosmos_client_mock) - yield operations_repo + operations_repo = await OperationRepository.create() + yield operations_repo @pytest.fixture diff --git a/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py b/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py index 10ce04e187..72a55ae9ba 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py @@ -19,17 +19,15 @@ @pytest_asyncio.fixture async def workspace_service_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - workspace_repo = await WorkspaceServiceRepository.create(cosmos_client_mock) - yield workspace_repo + workspace_repo = await WorkspaceServiceRepository.create() + yield workspace_repo @pytest_asyncio.fixture async def operations_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - operations_repo = await OperationRepository.create(cosmos_client_mock) - yield operations_repo + operations_repo = await OperationRepository.create() + yield operations_repo @pytest.fixture diff --git a/api_app/tests_ma/test_services/test_airlock.py b/api_app/tests_ma/test_services/test_airlock.py index f70f3f4003..b0ce1cb4e6 100644 --- a/api_app/tests_ma/test_services/test_airlock.py +++ b/api_app/tests_ma/test_services/test_airlock.py @@ -32,7 +32,7 @@ async def airlock_request_repo_mock(no_database): _ = no_database with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - airlock_request_repo_mock = await AirlockRequestRepository.create(cosmos_client_mock) + airlock_request_repo_mock = await AirlockRequestRepository.create() yield airlock_request_repo_mock From 7d85d046970c0a4686b193a2cca6467e9433dd46 Mon Sep 17 00:00:00 2001 From: marrobi Date: Tue, 2 Jan 2024 12:39:35 +0000 Subject: [PATCH 19/32] Remove uneeded cosmos references --- api_app/db/repositories/airlock_requests.py | 1 - api_app/tests_ma/test_api/test_routes/test_workspaces.py | 6 ++---- api_app/tests_ma/test_services/test_airlock.py | 5 ++--- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/api_app/db/repositories/airlock_requests.py b/api_app/db/repositories/airlock_requests.py index 0ec78f0452..f4a1926348 100644 --- a/api_app/db/repositories/airlock_requests.py +++ b/api_app/db/repositories/airlock_requests.py @@ -5,7 +5,6 @@ from typing import List, Optional from pydantic import UUID4 from azure.cosmos.exceptions import CosmosResourceNotFoundError, CosmosAccessConditionFailedError -from azure.cosmos.aio import CosmosClient from fastapi import HTTPException, status from pydantic import parse_obj_as from models.domain.authentication import User diff --git a/api_app/tests_ma/test_api/test_routes/test_workspaces.py b/api_app/tests_ma/test_api/test_routes/test_workspaces.py index 9b7ae7c18c..9279f2e26e 100644 --- a/api_app/tests_ma/test_api/test_routes/test_workspaces.py +++ b/api_app/tests_ma/test_api/test_routes/test_workspaces.py @@ -679,9 +679,8 @@ async def test_delete_workspace_returns_400_if_associated_workspace_services_are @ patch("api.dependencies.database.Database.get_repository") @ patch("api.dependencies.workspaces.WorkspaceRepository.get_workspace_by_id") @ patch("api.routes.workspaces.WorkspaceServiceRepository.get_active_workspace_services_for_workspace", return_value=[]) - @ patch('azure.cosmos.CosmosClient') @ patch('api.routes.resource_helpers.send_resource_request_message', return_value=sample_resource_operation(resource_id=WORKSPACE_ID, operation_id=OPERATION_ID)) - async def test_delete_workspace_sends_a_request_message_to_uninstall_the_workspace(self, send_request_message_mock, cosmos_client_mock, __, get_workspace_mock, get_repository_mock, resource_template_repo, ___, disabled_workspace, app, client, basic_resource_template): + async def test_delete_workspace_sends_a_request_message_to_uninstall_the_workspace(self, send_request_message_mock, __, get_workspace_mock, get_repository_mock, resource_template_repo, ___, disabled_workspace, app, client, basic_resource_template): get_workspace_mock.return_value = disabled_workspace get_repository_mock.side_effects = [await WorkspaceRepository.create(), await WorkspaceServiceRepository.create()] resource_template_repo.return_value = basic_resource_template @@ -695,8 +694,7 @@ async def test_delete_workspace_sends_a_request_message_to_uninstall_the_workspa @ patch("api.dependencies.database.Database.get_repository") @ patch("api.dependencies.workspaces.WorkspaceRepository.get_workspace_by_id") @ patch("api.routes.workspaces.WorkspaceServiceRepository.get_active_workspace_services_for_workspace") - @ patch('azure.cosmos.CosmosClient') - async def test_delete_workspace_raises_503_if_marking_the_resource_as_deleted_in_the_db_fails(self, __, ___, get_workspace_mock, _____, resource_template_repo, ______, client, app, disabled_workspace, basic_resource_template): + async def test_delete_workspace_raises_503_if_marking_the_resource_as_deleted_in_the_db_fails(self, ___, get_workspace_mock, _____, resource_template_repo, ______, client, app, disabled_workspace, basic_resource_template): get_workspace_mock.return_value = disabled_workspace resource_template_repo.return_value = basic_resource_template response = await client.delete(app.url_path_for(strings.API_DELETE_WORKSPACE, workspace_id=WORKSPACE_ID)) diff --git a/api_app/tests_ma/test_services/test_airlock.py b/api_app/tests_ma/test_services/test_airlock.py index b0ce1cb4e6..49eabeef1d 100644 --- a/api_app/tests_ma/test_services/test_airlock.py +++ b/api_app/tests_ma/test_services/test_airlock.py @@ -31,9 +31,8 @@ @pytest_asyncio.fixture async def airlock_request_repo_mock(no_database): _ = no_database - with patch('azure.cosmos.CosmosClient') as cosmos_client_mock: - airlock_request_repo_mock = await AirlockRequestRepository.create() - yield airlock_request_repo_mock + airlock_request_repo_mock = await AirlockRequestRepository.create() + yield airlock_request_repo_mock def sample_workspace(): From 7ab298f6c10c6f1600ec6ed30a2fafa3ffcd04f2 Mon Sep 17 00:00:00 2001 From: marrobi Date: Tue, 2 Jan 2024 12:54:50 +0000 Subject: [PATCH 20/32] Remove uneeded database imports --- api_app/service_bus/airlock_request_status_update.py | 1 - api_app/service_bus/deployment_status_updater.py | 1 - api_app/services/aad_authentication.py | 1 - 3 files changed, 3 deletions(-) diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index ad81402ebe..a01156bae9 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -6,7 +6,6 @@ from fastapi import HTTPException from pydantic import ValidationError, parse_obj_as -from api.dependencies.database import Database from api.dependencies.airlock import get_airlock_request_by_id_from_path from services.airlock import update_and_publish_event_airlock_request from services.logging import logger, tracer diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index de2a563e86..08b39d6b4e 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -4,7 +4,6 @@ from pydantic import ValidationError, parse_obj_as -from api.dependencies.database import Database from api.routes.resource_helpers import get_timestamp from models.domain.resource import Output from db.repositories.resources_history import ResourceHistoryRepository diff --git a/api_app/services/aad_authentication.py b/api_app/services/aad_authentication.py index bc39c48c87..70d8e19b45 100644 --- a/api_app/services/aad_authentication.py +++ b/api_app/services/aad_authentication.py @@ -14,7 +14,6 @@ from models.domain.authentication import User, RoleAssignment from models.domain.workspace import Workspace, WorkspaceRole from resources import strings -from api.dependencies.database import Database from db.repositories.workspaces import WorkspaceRepository from services.logging import logger From b815846119efd3c38ac22351da4e42f732577225 Mon Sep 17 00:00:00 2001 From: marrobi Date: Tue, 2 Jan 2024 15:17:36 +0000 Subject: [PATCH 21/32] fix async issue with FastAPI depends --- api_app/_version.py | 2 +- api_app/api/dependencies/airlock.py | 3 +- api_app/api/dependencies/database.py | 21 +---- api_app/api/dependencies/shared_services.py | 5 +- .../workspace_service_templates.py | 3 +- api_app/api/dependencies/workspaces.py | 13 ++- api_app/api/routes/airlock.py | 45 ++++++----- api_app/api/routes/api.py | 5 +- api_app/api/routes/costs.py | 11 ++- api_app/api/routes/migrations.py | 15 ++-- api_app/api/routes/operations.py | 3 +- .../api/routes/shared_service_templates.py | 7 +- api_app/api/routes/shared_services.py | 17 ++-- api_app/api/routes/user_resource_templates.py | 7 +- .../api/routes/workspace_service_templates.py | 7 +- api_app/api/routes/workspace_templates.py | 7 +- api_app/api/routes/workspaces.py | 79 +++++++++---------- api_app/db/migrations/airlock.py | 17 ++-- api_app/db/repositories/base.py | 25 +++++- .../test_routes/test_resource_helpers.py | 6 +- .../test_api/test_routes/test_workspaces.py | 4 +- .../test_repositories/test_base_repository.py | 2 +- .../test_resource_history_repository.py | 2 +- .../test_resource_repository.py | 4 +- .../test_resource_templates_repository.py | 2 +- .../test_shared_service_repository.py | 4 +- ...est_shared_service_templates_repository.py | 2 +- .../test_user_resource_repository.py | 2 +- ...test_user_resource_templates_repository.py | 2 +- .../test_workpaces_repository.py | 4 +- .../test_workpaces_service_repository.py | 4 +- .../test_deployment_status_update.py | 8 +- .../test_resource_request_sender.py | 22 +++--- .../tests_ma/test_services/test_airlock.py | 2 +- 34 files changed, 172 insertions(+), 190 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index e61f7a55d6..391a39001a 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.18.4" +__version__ = "0.18.5" diff --git a/api_app/api/dependencies/airlock.py b/api_app/api/dependencies/airlock.py index efca378996..1a8ee75994 100644 --- a/api_app/api/dependencies/airlock.py +++ b/api_app/api/dependencies/airlock.py @@ -1,7 +1,6 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api.dependencies.database import Database from db.repositories.airlock_requests import AirlockRequestRepository from models.domain.airlock_request import AirlockRequest from db.errors import EntityDoesNotExist, UnableToAccessDatabase @@ -17,5 +16,5 @@ async def get_airlock_request_by_id(airlock_request_id: UUID4, airlock_request_r raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=strings.STATE_STORE_ENDPOINT_NOT_RESPONDING) -async def get_airlock_request_by_id_from_path(airlock_request_id: UUID4 = Path(...), airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository))) -> AirlockRequest: +async def get_airlock_request_by_id_from_path(airlock_request_id: UUID4 = Path(...), airlock_request_repo=Depends(AirlockRequestRepository.get_repository())) -> AirlockRequest: return await get_airlock_request_by_id(airlock_request_id, airlock_request_repo) diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index 9c08caa8aa..a22d548f32 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -1,13 +1,8 @@ -from typing import Callable, Type - from azure.cosmos.aio import CosmosClient from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient -from fastapi import HTTPException, status + from core.config import MANAGED_IDENTITY_CLIENT_ID, STATE_STORE_ENDPOINT, STATE_STORE_KEY, STATE_STORE_SSL_VERIFY, SUBSCRIPTION_ID, RESOURCE_MANAGER_ENDPOINT, CREDENTIAL_SCOPES, RESOURCE_GROUP_NAME, COSMOSDB_ACCOUNT_NAME from core.credentials import get_credential_async -from db.errors import UnableToAccessDatabase -from db.repositories.base import BaseRepository -from resources import strings from services.logging import logger @@ -84,17 +79,3 @@ async def get_db_client(self) -> CosmosClient: if not Database.cosmos_client: Database.cosmos_client = await self._connect_to_db() return self.cosmos_client - - @classmethod - def get_repository(self, repo_type: Type[BaseRepository]) -> Callable: - async def _get_repo() -> BaseRepository: - try: - return await repo_type.create() - except UnableToAccessDatabase: - logger.exception(strings.STATE_STORE_ENDPOINT_NOT_RESPONDING) - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=strings.STATE_STORE_ENDPOINT_NOT_RESPONDING, - ) - - return _get_repo diff --git a/api_app/api/dependencies/shared_services.py b/api_app/api/dependencies/shared_services.py index 388ec8a3e5..f84a5e4b6a 100644 --- a/api_app/api/dependencies/shared_services.py +++ b/api_app/api/dependencies/shared_services.py @@ -1,7 +1,6 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api.dependencies.database import Database from db.errors import EntityDoesNotExist from resources import strings from models.domain.shared_service import SharedService @@ -17,11 +16,11 @@ async def get_shared_service_by_id(shared_service_id: UUID4, shared_services_rep raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.SHARED_SERVICE_DOES_NOT_EXIST) -async def get_shared_service_by_id_from_path(shared_service_id: UUID4 = Path(...), shared_service_repo=Depends(Database().get_repository(SharedServiceRepository))) -> SharedService: +async def get_shared_service_by_id_from_path(shared_service_id: UUID4 = Path(...), shared_service_repo=Depends(SharedServiceRepository.get_repository())) -> SharedService: return await get_shared_service_by_id(shared_service_id, shared_service_repo) -async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(Database().get_repository(OperationRepository))) -> Operation: +async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(OperationRepository.get_repository())) -> Operation: try: return await operations_repo.get_operation_by_id(operation_id=operation_id) except EntityDoesNotExist: diff --git a/api_app/api/dependencies/workspace_service_templates.py b/api_app/api/dependencies/workspace_service_templates.py index 2a6e908722..b06864edcd 100644 --- a/api_app/api/dependencies/workspace_service_templates.py +++ b/api_app/api/dependencies/workspace_service_templates.py @@ -1,6 +1,5 @@ from fastapi import Depends, HTTPException, Path, status -from api.dependencies.database import Database from db.errors import EntityDoesNotExist from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType @@ -8,7 +7,7 @@ from resources import strings -async def get_workspace_service_template_by_name_from_path(service_template_name: str = Path(...), template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplate: +async def get_workspace_service_template_by_name_from_path(service_template_name: str = Path(...), template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplate: try: return await template_repo.get_current_template(service_template_name, ResourceType.WorkspaceService) except EntityDoesNotExist: diff --git a/api_app/api/dependencies/workspaces.py b/api_app/api/dependencies/workspaces.py index 90f98c63b5..56263b3938 100644 --- a/api_app/api/dependencies/workspaces.py +++ b/api_app/api/dependencies/workspaces.py @@ -1,7 +1,6 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api.dependencies.database import Database from db.errors import EntityDoesNotExist, ResourceIsNotDeployed from db.repositories.operations import OperationRepository from db.repositories.user_resources import UserResourceRepository @@ -22,11 +21,11 @@ async def get_workspace_by_id(workspace_id: UUID4, workspaces_repo) -> Workspace raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.WORKSPACE_DOES_NOT_EXIST) -async def get_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(Database().get_repository(WorkspaceRepository))) -> Workspace: +async def get_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(WorkspaceRepository)) -> Workspace: return await get_workspace_by_id(workspace_id, workspaces_repo) -async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(Database().get_repository(WorkspaceRepository)), operations_repo=Depends(Database().get_repository(OperationRepository))) -> Workspace: +async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(WorkspaceRepository), operations_repo=Depends(OperationRepository.get_repository())) -> Workspace: try: return await workspaces_repo.get_deployed_workspace_by_id(workspace_id, operations_repo) except EntityDoesNotExist: @@ -35,14 +34,14 @@ async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...) raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=strings.WORKSPACE_IS_NOT_DEPLOYED) -async def get_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(Database().get_repository(WorkspaceServiceRepository))) -> WorkspaceService: +async def get_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(WorkspaceServiceRepository)) -> WorkspaceService: try: return await workspace_services_repo.get_workspace_service_by_id(workspace_id, service_id) except EntityDoesNotExist: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.WORKSPACE_SERVICE_DOES_NOT_EXIST) -async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), operations_repo=Depends(Database().get_repository(OperationRepository))) -> WorkspaceService: +async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(WorkspaceServiceRepository), operations_repo=Depends(OperationRepository.get_repository())) -> WorkspaceService: try: return await workspace_services_repo.get_deployed_workspace_service_by_id(workspace_id, service_id, operations_repo) except EntityDoesNotExist: @@ -51,14 +50,14 @@ async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = P raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=strings.WORKSPACE_SERVICE_IS_NOT_DEPLOYED) -async def get_user_resource_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), resource_id: UUID4 = Path(...), user_resource_repo=Depends(Database().get_repository(UserResourceRepository))) -> UserResource: +async def get_user_resource_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), resource_id: UUID4 = Path(...), user_resource_repo=Depends(UserResourceRepository)) -> UserResource: try: return await user_resource_repo.get_user_resource_by_id(workspace_id, service_id, resource_id) except EntityDoesNotExist: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.USER_RESOURCE_DOES_NOT_EXIST) -async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(Database().get_repository(OperationRepository))) -> Operation: +async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(OperationRepository.get_repository())) -> Operation: try: return await operations_repo.get_operation_by_id(operation_id=operation_id) except EntityDoesNotExist: diff --git a/api_app/api/routes/airlock.py b/api_app/api/routes/airlock.py index 900b418867..d7aa42b378 100644 --- a/api_app/api/routes/airlock.py +++ b/api_app/api/routes/airlock.py @@ -11,7 +11,6 @@ from db.repositories.airlock_requests import AirlockRequestRepository from db.errors import EntityDoesNotExist, UserNotAuthorizedToUseTemplate -from api.dependencies.database import Database from api.dependencies.workspaces import get_workspace_by_id_from_path, get_deployed_workspace_by_id_from_path from api.dependencies.airlock import get_airlock_request_by_id_from_path from models.domain.airlock_request import AirlockRequestStatus, AirlockRequestType @@ -36,7 +35,7 @@ response_model=AirlockRequestWithAllowedUserActions, name=strings.API_CREATE_AIRLOCK_REQUEST, dependencies=[Depends(get_current_workspace_owner_or_researcher_user), Depends(get_workspace_by_id_from_path)]) async def create_draft_request(airlock_request_input: AirlockRequestInCreate, user=Depends(get_current_workspace_owner_or_researcher_user), - airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), + airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), workspace=Depends(get_deployed_workspace_by_id_from_path)) -> AirlockRequestWithAllowedUserActions: if workspace.properties.get("enable_airlock") is False: raise HTTPException(status_code=status_code.HTTP_405_METHOD_NOT_ALLOWED, detail=strings.AIRLOCK_NOT_ENABLED_IN_WORKSPACE) @@ -57,7 +56,7 @@ async def create_draft_request(airlock_request_input: AirlockRequestInCreate, us dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) async def get_all_airlock_requests_by_workspace( - airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), + airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), workspace=Depends(get_deployed_workspace_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), creator_user_id: Optional[str] = None, type: Optional[AirlockRequestType] = None, status: Optional[AirlockRequestStatus] = None, @@ -77,7 +76,7 @@ async def get_all_airlock_requests_by_workspace( response_model=AirlockRequestWithAllowedUserActions, name=strings.API_GET_AIRLOCK_REQUEST, dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) async def retrieve_airlock_request_by_id(airlock_request=Depends(get_airlock_request_by_id_from_path), - airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), + airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)) -> AirlockRequestWithAllowedUserActions: allowed_actions = get_allowed_actions(airlock_request, user, airlock_request_repo) return AirlockRequestWithAllowedUserActions(airlockRequest=airlock_request, allowedUserActions=allowed_actions) @@ -88,7 +87,7 @@ async def retrieve_airlock_request_by_id(airlock_request=Depends(get_airlock_req dependencies=[Depends(get_current_workspace_owner_or_researcher_user), Depends(get_workspace_by_id_from_path)]) async def create_submit_request(airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user), - airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), + airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), workspace=Depends(get_workspace_by_id_from_path)) -> AirlockRequestWithAllowedUserActions: updated_request = await update_and_publish_event_airlock_request(airlock_request, airlock_request_repo, user, workspace, new_status=AirlockRequestStatus.Submitted) @@ -102,12 +101,12 @@ async def create_submit_request(airlock_request=Depends(get_airlock_request_by_i async def create_cancel_request(airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user), workspace=Depends(get_workspace_by_id_from_path), - airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), - user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), - workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), - resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), - operation_repo=Depends(Database().get_repository(OperationRepository)), - resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)),) -> AirlockRequestWithAllowedUserActions: + airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), + user_resource_repo=Depends(UserResourceRepository.get_repository()), + workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), + resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), + operation_repo=Depends(OperationRepository.get_repository()), + resource_template_repo=Depends(ResourceTemplateRepository.get_repository()),) -> AirlockRequestWithAllowedUserActions: updated_request = await cancel_request(airlock_request, user, workspace, airlock_request_repo, user_resource_repo, workspace_service_repo, resource_template_repo, operation_repo, resource_history_repo) allowed_actions = get_allowed_actions(updated_request, user, airlock_request_repo) return AirlockRequestWithAllowedUserActions(airlockRequest=updated_request, allowedUserActions=allowed_actions) @@ -122,12 +121,12 @@ async def create_review_user_resource( airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_airlock_manager_user), workspace=Depends(get_deployed_workspace_by_id_from_path), - user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), - workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), - operation_repo=Depends(Database().get_repository(OperationRepository)), - airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), - resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), - resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> AirlockRequestAndOperationInResponse: + user_resource_repo=Depends(UserResourceRepository.get_repository()), + workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), + operation_repo=Depends(OperationRepository.get_repository()), + airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), + resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), + resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> AirlockRequestAndOperationInResponse: if airlock_request.status != AirlockRequestStatus.InReview: raise HTTPException(status_code=status_code.HTTP_400_BAD_REQUEST, @@ -160,12 +159,12 @@ async def create_airlock_review( airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_airlock_manager_user), workspace=Depends(get_deployed_workspace_by_id_from_path), - airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository)), - user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), - workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), - operation_repo=Depends(Database().get_repository(OperationRepository)), - resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), - resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> AirlockRequestWithAllowedUserActions: + airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), + user_resource_repo=Depends(UserResourceRepository.get_repository()), + workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), + operation_repo=Depends(OperationRepository.get_repository()), + resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), + resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> AirlockRequestWithAllowedUserActions: try: updated_airlock_request = await review_airlock_request(airlock_review_input, airlock_request, user, workspace, airlock_request_repo, user_resource_repo, workspace_service_repo, operation_repo, resource_template_repo, resource_history_repo) allowed_actions = get_allowed_actions(updated_airlock_request, user, airlock_request_repo) diff --git a/api_app/api/routes/api.py b/api_app/api/routes/api.py index 025717bc7a..99bf4b9ee9 100644 --- a/api_app/api/routes/api.py +++ b/api_app/api/routes/api.py @@ -5,7 +5,6 @@ from fastapi.openapi.docs import get_swagger_ui_html, get_swagger_ui_oauth2_redirect_html from fastapi.openapi.utils import get_openapi -from api.dependencies.database import Database from db.repositories.workspaces import WorkspaceRepository from api.routes import health, ping, workspaces, workspace_templates, workspace_service_templates, user_resource_templates, \ shared_services, shared_service_templates, migrations, costs, airlock, operations, metadata @@ -116,7 +115,7 @@ def get_scope(workspace) -> str: @workspace_swagger_router.get("/workspaces/{workspace_id}/openapi.json", include_in_schema=False, name="openapi_definitions") -async def get_openapi_json(workspace_id: str, request: Request, workspace_repo=Depends(Database().get_repository(WorkspaceRepository))): +async def get_openapi_json(workspace_id: str, request: Request, workspace_repo=Depends(WorkspaceRepository.get_repository())): global openapi_definitions if openapi_definitions[workspace_id] is None: @@ -146,7 +145,7 @@ async def get_openapi_json(workspace_id: str, request: Request, workspace_repo=D @workspace_swagger_router.get("/workspaces/{workspace_id}/docs", include_in_schema=False, name="workspace_swagger") -async def get_workspace_swagger(workspace_id, request: Request, workspace_repo=Depends(Database().get_repository(WorkspaceRepository))): +async def get_workspace_swagger(workspace_id, request: Request, workspace_repo=Depends(WorkspaceRepository.get_repository())): workspace = await workspace_repo.get_workspace_by_id(workspace_id) scope = get_scope(workspace) diff --git a/api_app/api/routes/costs.py b/api_app/api/routes/costs.py index 97d353bbb3..57e1352451 100644 --- a/api_app/api/routes/costs.py +++ b/api_app/api/routes/costs.py @@ -7,7 +7,6 @@ from pydantic import UUID4 from models.schemas.costs import get_cost_report_responses, get_workspace_cost_report_responses -from api.dependencies.database import Database from core import config from db.repositories.shared_services import SharedServiceRepository from db.repositories.user_resources import UserResourceRepository @@ -56,8 +55,8 @@ def __init__( async def costs( params: CostsQueryParams = Depends(), cost_service: CostService = Depends(cost_service_factory), - workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), - shared_services_repo=Depends(Database().get_repository(SharedServiceRepository))) -> CostReport: + workspace_repo=Depends(WorkspaceRepository.get_repository()), + shared_services_repo=Depends(SharedServiceRepository.get_repository())) -> CostReport: validate_report_period(params.from_date, params.to_date) try: @@ -90,9 +89,9 @@ async def costs( responses=get_workspace_cost_report_responses()) async def workspace_costs(workspace_id: UUID4, params: CostsQueryParams = Depends(), cost_service: CostService = Depends(cost_service_factory), - workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), - workspace_services_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), - user_resource_repo=Depends(Database().get_repository(UserResourceRepository))) -> WorkspaceCostReport: + workspace_repo=Depends(WorkspaceRepository.get_repository()), + workspace_services_repo=Depends(WorkspaceServiceRepository.get_repository()), + user_resource_repo=Depends(UserResourceRepository.get_repository())) -> WorkspaceCostReport: validate_report_period(params.from_date, params.to_date) try: diff --git a/api_app/api/routes/migrations.py b/api_app/api/routes/migrations.py index fcbd32e22b..cbb420fd1a 100644 --- a/api_app/api/routes/migrations.py +++ b/api_app/api/routes/migrations.py @@ -5,7 +5,6 @@ from db.repositories.resources_history import ResourceHistoryRepository from services.authentication import get_current_admin_user from resources import strings -from api.dependencies.database import Database from db.migrations.shared_services import SharedServiceMigration from db.migrations.workspaces import WorkspaceMigration from db.repositories.resources import ResourceRepository @@ -20,13 +19,13 @@ name=strings.API_MIGRATE_DATABASE, response_model=MigrationOutList, dependencies=[Depends(get_current_admin_user)]) -async def migrate_database(resources_repo=Depends(Database().get_repository(ResourceRepository)), - operations_repo=Depends(Database().get_repository(OperationRepository)), - resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), - shared_services_migration=Depends(Database().get_repository(SharedServiceMigration)), - workspace_migration=Depends(Database().get_repository(WorkspaceMigration)), - resource_migration=Depends(Database().get_repository(ResourceMigration)), - airlock_migration=Depends(Database().get_repository(AirlockMigration)),): +async def migrate_database(resources_repo=Depends(ResourceRepository.get_repository()), + operations_repo=Depends(OperationRepository.get_repository()), + resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), + shared_services_migration=Depends(SharedServiceMigration().get_repository()), + workspace_migration=Depends(WorkspaceMigration().get_repository()), + resource_migration=Depends(ResourceMigration().get_repository()), + airlock_migration=Depends(AirlockMigration().get_repository())): try: migrations = list() logger.info("PR 1030") diff --git a/api_app/api/routes/operations.py b/api_app/api/routes/operations.py index ee3c693a5b..0f2d1cc955 100644 --- a/api_app/api/routes/operations.py +++ b/api_app/api/routes/operations.py @@ -1,7 +1,6 @@ from fastapi import APIRouter, Depends from db.repositories.operations import OperationRepository -from api.dependencies.database import Database from models.schemas.operation import OperationInList from resources import strings from services.authentication import get_current_tre_user_or_tre_admin @@ -11,6 +10,6 @@ @operations_router.get("/operations", response_model=OperationInList, name=strings.API_GET_MY_OPERATIONS) -async def get_my_operations(user=Depends(get_current_tre_user_or_tre_admin), operations_repo=Depends(Database().get_repository(OperationRepository))) -> OperationInList: +async def get_my_operations(user=Depends(get_current_tre_user_or_tre_admin), operations_repo=Depends(OperationRepository.get_repository())) -> OperationInList: operations = await operations_repo.get_my_operations(user_id=user.id) return OperationInList(operations=operations) diff --git a/api_app/api/routes/shared_service_templates.py b/api_app/api/routes/shared_service_templates.py index a2c32f5aa6..57effd152c 100644 --- a/api_app/api/routes/shared_service_templates.py +++ b/api_app/api/routes/shared_service_templates.py @@ -2,7 +2,6 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import Database from db.errors import EntityDoesNotExist, EntityVersionExist, InvalidInput from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType @@ -17,13 +16,13 @@ @shared_service_templates_core_router.get("/shared-service-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_SHARED_SERVICE_TEMPLATES) -async def get_shared_service_templates(authorized_only: bool = False, template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), user=Depends(get_current_tre_user_or_tre_admin)) -> ResourceTemplateInformationInList: +async def get_shared_service_templates(authorized_only: bool = False, template_repo=Depends(ResourceTemplateRepository.get_repository()), user=Depends(get_current_tre_user_or_tre_admin)) -> ResourceTemplateInformationInList: templates_infos = await template_repo.get_templates_information(ResourceType.SharedService, user.roles if authorized_only else None) return ResourceTemplateInformationInList(templates=templates_infos) @shared_service_templates_core_router.get("/shared-service-templates/{shared_service_template_name}", response_model=SharedServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_GET_SHARED_SERVICE_TEMPLATE_BY_NAME, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_shared_service_template(shared_service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> SharedServiceTemplateInResponse: +async def get_shared_service_template(shared_service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> SharedServiceTemplateInResponse: try: template = await get_template(shared_service_template_name, template_repo, ResourceType.SharedService, is_update=is_update, version=version) return parse_obj_as(SharedServiceTemplateInResponse, template) @@ -32,7 +31,7 @@ async def get_shared_service_template(shared_service_template_name: str, is_upda @shared_service_templates_core_router.post("/shared-service-templates", status_code=status.HTTP_201_CREATED, response_model=SharedServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_SHARED_SERVICE_TEMPLATES, dependencies=[Depends(get_current_admin_user)]) -async def register_shared_service_template(template_input: SharedServiceTemplateInCreate, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: +async def register_shared_service_template(template_input: SharedServiceTemplateInCreate, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.SharedService) except EntityVersionExist: diff --git a/api_app/api/routes/shared_services.py b/api_app/api/routes/shared_services.py index baee1bbc02..2f7831b728 100644 --- a/api_app/api/routes/shared_services.py +++ b/api_app/api/routes/shared_services.py @@ -5,7 +5,6 @@ from db.repositories.operations import OperationRepository from db.errors import DuplicateEntity, MajorVersionUpdateDenied, UserNotAuthorizedToUseTemplate, TargetTemplateVersionDoesNotExist, VersionDowngradeDenied -from api.dependencies.database import Database from api.dependencies.shared_services import get_shared_service_by_id_from_path, get_operation_by_id_from_path from db.repositories.resource_templates import ResourceTemplateRepository from db.repositories.resources_history import ResourceHistoryRepository @@ -33,7 +32,7 @@ def user_is_tre_admin(user): @shared_services_router.get("/shared-services", response_model=SharedServicesInList, name=strings.API_GET_ALL_SHARED_SERVICES, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def retrieve_shared_services(shared_services_repo=Depends(Database().get_repository(SharedServiceRepository)), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> SharedServicesInList: +async def retrieve_shared_services(shared_services_repo=Depends(SharedServiceRepository.get_repository()), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())) -> SharedServicesInList: shared_services = await shared_services_repo.get_active_shared_services() await asyncio.gather(*[enrich_resource_with_available_upgrades(shared_service, resource_template_repo) for shared_service in shared_services]) if user_is_tre_admin(user): @@ -43,7 +42,7 @@ async def retrieve_shared_services(shared_services_repo=Depends(Database().get_r @shared_services_router.get("/shared-services/{shared_service_id}", response_model=SharedServiceInResponse, name=strings.API_GET_SHARED_SERVICE_BY_ID, dependencies=[Depends(get_current_tre_user_or_tre_admin), Depends(get_shared_service_by_id_from_path)]) -async def retrieve_shared_service_by_id(shared_service=Depends(get_shared_service_by_id_from_path), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))): +async def retrieve_shared_service_by_id(shared_service=Depends(get_shared_service_by_id_from_path), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())): await enrich_resource_with_available_upgrades(shared_service, resource_template_repo) if user_is_tre_admin(user): return SharedServiceInResponse(sharedService=shared_service) @@ -52,7 +51,7 @@ async def retrieve_shared_service_by_id(shared_service=Depends(get_shared_servic @shared_services_router.post("/shared-services", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_CREATE_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user)]) -async def create_shared_service(response: Response, shared_service_input: SharedServiceInCreate, user=Depends(get_current_admin_user), shared_services_repo=Depends(Database().get_repository(SharedServiceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def create_shared_service(response: Response, shared_service_input: SharedServiceInCreate, user=Depends(get_current_admin_user), shared_services_repo=Depends(SharedServiceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: try: shared_service, resource_template = await shared_services_repo.create_shared_service_item(shared_service_input, user.roles) except (ValidationError, ValueError) as e: @@ -83,7 +82,7 @@ async def create_shared_service(response: Response, shared_service_input: Shared response_model=OperationInResponse, name=strings.API_UPDATE_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user), Depends(get_shared_service_by_id_from_path)]) -async def patch_shared_service(shared_service_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), shared_service_repo=Depends(Database().get_repository(SharedServiceRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), etag: str = Header(...), force_version_update: bool = False) -> SharedServiceInResponse: +async def patch_shared_service(shared_service_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), shared_service_repo=Depends(SharedServiceRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), etag: str = Header(...), force_version_update: bool = False) -> SharedServiceInResponse: try: patched_shared_service, _ = await shared_service_repo.patch_shared_service(shared_service, shared_service_patch, etag, resource_template_repo, resource_history_repo, user, force_version_update) operation = await send_resource_request_message( @@ -106,7 +105,7 @@ async def patch_shared_service(shared_service_patch: ResourcePatch, response: Re @shared_services_router.delete("/shared-services/{shared_service_id}", response_model=OperationInResponse, name=strings.API_DELETE_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user)]) -async def delete_shared_service(response: Response, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(Database().get_repository(OperationRepository)), shared_service_repo=Depends(Database().get_repository(SharedServiceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def delete_shared_service(response: Response, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(OperationRepository.get_repository()), shared_service_repo=Depends(SharedServiceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: if shared_service.isEnabled: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=strings.SHARED_SERVICE_NEEDS_TO_BE_DISABLED_BEFORE_DELETION) @@ -125,7 +124,7 @@ async def delete_shared_service(response: Response, user=Depends(get_current_adm @shared_services_router.post("/shared-services/{shared_service_id}/invoke-action", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_INVOKE_ACTION_ON_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user)]) -async def invoke_action_on_shared_service(response: Response, action: str, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), shared_service_repo=Depends(Database().get_repository(SharedServiceRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def invoke_action_on_shared_service(response: Response, action: str, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), shared_service_repo=Depends(SharedServiceRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: operation = await send_custom_action_message( resource=shared_service, resource_repo=shared_service_repo, @@ -143,7 +142,7 @@ async def invoke_action_on_shared_service(response: Response, action: str, user= # Shared service operations @shared_services_router.get("/shared-services/{shared_service_id}/operations", response_model=OperationInList, name=strings.API_GET_RESOURCE_OPERATIONS, dependencies=[Depends(get_current_admin_user), Depends(get_shared_service_by_id_from_path)]) -async def retrieve_shared_service_operations_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(Database().get_repository(OperationRepository))) -> OperationInList: +async def retrieve_shared_service_operations_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(OperationRepository.get_repository())) -> OperationInList: return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=shared_service.id)) @@ -154,5 +153,5 @@ async def retrieve_shared_service_operation_by_shared_service_id_and_operation_i # Shared service history @shared_services_router.get("/shared-services/{shared_service_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_current_admin_user)]) -async def retrieve_shared_service_history_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: +async def retrieve_shared_service_history_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> ResourceHistoryInList: return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=shared_service.id)) diff --git a/api_app/api/routes/user_resource_templates.py b/api_app/api/routes/user_resource_templates.py index 1f0d860652..4330a3b688 100644 --- a/api_app/api/routes/user_resource_templates.py +++ b/api_app/api/routes/user_resource_templates.py @@ -3,7 +3,6 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import Database from api.dependencies.workspace_service_templates import get_workspace_service_template_by_name_from_path from api.routes.resource_helpers import get_template from db.errors import EntityVersionExist, InvalidInput @@ -19,19 +18,19 @@ @user_resource_templates_core_router.get("/workspace-service-templates/{service_template_name}/user-resource-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_USER_RESOURCE_TEMPLATES, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_user_resource_templates_for_service_template(service_template_name: str, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplateInformationInList: +async def get_user_resource_templates_for_service_template(service_template_name: str, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplateInformationInList: template_infos = await template_repo.get_templates_information(ResourceType.UserResource, parent_service_name=service_template_name) return ResourceTemplateInformationInList(templates=template_infos) @user_resource_templates_core_router.get("/workspace-service-templates/{service_template_name}/user-resource-templates/{user_resource_template_name}", response_model=UserResourceTemplateInResponse, response_model_exclude_none=True, name=strings.API_GET_USER_RESOURCE_TEMPLATE_BY_NAME, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_user_resource_template(service_template_name: str, user_resource_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> UserResourceTemplateInResponse: +async def get_user_resource_template(service_template_name: str, user_resource_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> UserResourceTemplateInResponse: template = await get_template(user_resource_template_name, template_repo, ResourceType.UserResource, service_template_name, is_update=is_update, version=version) return parse_obj_as(UserResourceTemplateInResponse, template) @user_resource_templates_core_router.post("/workspace-service-templates/{service_template_name}/user-resource-templates", status_code=status.HTTP_201_CREATED, response_model=UserResourceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_USER_RESOURCE_TEMPLATES, dependencies=[Depends(get_current_admin_user)]) -async def register_user_resource_template(template_input: UserResourceTemplateInCreate, template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), workspace_service_template=Depends(get_workspace_service_template_by_name_from_path)) -> UserResourceTemplateInResponse: +async def register_user_resource_template(template_input: UserResourceTemplateInCreate, template_repo=Depends(ResourceTemplateRepository.get_repository()), workspace_service_template=Depends(get_workspace_service_template_by_name_from_path)) -> UserResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.UserResource, workspace_service_template.name) except EntityVersionExist: diff --git a/api_app/api/routes/workspace_service_templates.py b/api_app/api/routes/workspace_service_templates.py index 2db35543c5..6411e15e1b 100644 --- a/api_app/api/routes/workspace_service_templates.py +++ b/api_app/api/routes/workspace_service_templates.py @@ -2,7 +2,6 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import Database from api.routes.resource_helpers import get_template from db.errors import EntityVersionExist, InvalidInput from db.repositories.resource_templates import ResourceTemplateRepository @@ -17,19 +16,19 @@ @workspace_service_templates_core_router.get("/workspace-service-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_WORKSPACE_SERVICE_TEMPLATES, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_workspace_service_templates(template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplateInformationInList: +async def get_workspace_service_templates(template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplateInformationInList: templates_infos = await template_repo.get_templates_information(ResourceType.WorkspaceService) return ResourceTemplateInformationInList(templates=templates_infos) @workspace_service_templates_core_router.get("/workspace-service-templates/{service_template_name}", response_model=WorkspaceServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_GET_WORKSPACE_SERVICE_TEMPLATE_BY_NAME, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_workspace_service_template(service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspaceServiceTemplateInResponse: +async def get_workspace_service_template(service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspaceServiceTemplateInResponse: template = await get_template(service_template_name, template_repo, ResourceType.WorkspaceService, is_update=is_update, version=version) return parse_obj_as(WorkspaceServiceTemplateInResponse, template) @workspace_service_templates_core_router.post("/workspace-service-templates", status_code=status.HTTP_201_CREATED, response_model=WorkspaceServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_WORKSPACE_SERVICE_TEMPLATES, dependencies=[Depends(get_current_admin_user)]) -async def register_workspace_service_template(template_input: WorkspaceServiceTemplateInCreate, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: +async def register_workspace_service_template(template_input: WorkspaceServiceTemplateInCreate, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.WorkspaceService) except EntityVersionExist: diff --git a/api_app/api/routes/workspace_templates.py b/api_app/api/routes/workspace_templates.py index 56d5c3153b..8dd0efda4c 100644 --- a/api_app/api/routes/workspace_templates.py +++ b/api_app/api/routes/workspace_templates.py @@ -2,7 +2,6 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as -from api.dependencies.database import Database from db.errors import EntityVersionExist, InvalidInput from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType @@ -17,19 +16,19 @@ @workspace_templates_admin_router.get("/workspace-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_WORKSPACE_TEMPLATES) -async def get_workspace_templates(authorized_only: bool = False, template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), user=Depends(get_current_admin_user)) -> ResourceTemplateInformationInList: +async def get_workspace_templates(authorized_only: bool = False, template_repo=Depends(ResourceTemplateRepository.get_repository()), user=Depends(get_current_admin_user)) -> ResourceTemplateInformationInList: templates_infos = await template_repo.get_templates_information(ResourceType.Workspace, user.roles if authorized_only else None) return ResourceTemplateInformationInList(templates=templates_infos) @workspace_templates_admin_router.get("/workspace-templates/{workspace_template_name}", response_model=WorkspaceTemplateInResponse, name=strings.API_GET_WORKSPACE_TEMPLATE_BY_NAME, response_model_exclude_none=True) -async def get_workspace_template(workspace_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspaceTemplateInResponse: +async def get_workspace_template(workspace_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspaceTemplateInResponse: template = await get_template(workspace_template_name, template_repo, ResourceType.Workspace, is_update=is_update, version=version) return parse_obj_as(WorkspaceTemplateInResponse, template) @workspace_templates_admin_router.post("/workspace-templates", status_code=status.HTTP_201_CREATED, response_model=WorkspaceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_WORKSPACE_TEMPLATES) -async def register_workspace_template(template_input: WorkspaceTemplateInCreate, template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: +async def register_workspace_template(template_input: WorkspaceTemplateInCreate, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.Workspace) except EntityVersionExist: diff --git a/api_app/api/routes/workspaces.py b/api_app/api/routes/workspaces.py index a0a7da34eb..9c2d38caf6 100644 --- a/api_app/api/routes/workspaces.py +++ b/api_app/api/routes/workspaces.py @@ -4,7 +4,6 @@ from jsonschema.exceptions import ValidationError -from api.dependencies.database import Database from api.dependencies.workspaces import get_operation_by_id_from_path, get_workspace_by_id_from_path, get_deployed_workspace_by_id_from_path, get_deployed_workspace_service_by_id_from_path, get_workspace_service_by_id_from_path, get_user_resource_by_id_from_path from db.errors import InvalidInput, MajorVersionUpdateDenied, TargetTemplateVersionDoesNotExist, UserNotAuthorizedToUseTemplate, VersionDowngradeDenied from db.repositories.operations import OperationRepository @@ -56,7 +55,7 @@ def validate_user_has_valid_role_for_user_resource(user, user_resource): # WORKSPACE ROUTES @workspaces_core_router.get("/workspaces", response_model=WorkspacesInList, name=strings.API_GET_ALL_WORKSPACES) -async def retrieve_users_active_workspaces(request: Request, user=Depends(get_current_tre_user_or_tre_admin), workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspacesInList: +async def retrieve_users_active_workspaces(request: Request, user=Depends(get_current_tre_user_or_tre_admin), workspace_repo=Depends(WorkspaceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspacesInList: try: user = await get_current_admin_user(request) @@ -83,7 +82,7 @@ def _safe_get_workspace_role(user, workspace, user_role_assignments): @workspaces_shared_router.get("/workspaces/{workspace_id}", response_model=WorkspaceInResponse, name=strings.API_GET_WORKSPACE_BY_ID) -async def retrieve_workspace_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspaceInResponse: +async def retrieve_workspace_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspaceInResponse: await enrich_resource_with_available_upgrades(workspace, resource_template_repo) return WorkspaceInResponse(workspace=workspace) @@ -97,7 +96,7 @@ async def retrieve_workspace_scope_id_by_workspace_id(workspace=Depends(get_work @workspaces_core_router.post("/workspaces", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_CREATE_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def create_workspace(workspace_create: WorkspaceInCreate, response: Response, user=Depends(get_current_admin_user), workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def create_workspace(workspace_create: WorkspaceInCreate, response: Response, user=Depends(get_current_admin_user), workspace_repo=Depends(WorkspaceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: try: # TODO: This requires Directory.ReadAll ( Application.Read.All ) to be enabled in the Azure AD application to enable a users workspaces to be listed. This should be made optional. auth_info = extract_auth_information(workspace_create.properties) @@ -125,7 +124,7 @@ async def create_workspace(workspace_create: WorkspaceInCreate, response: Respon @workspaces_core_router.patch("/workspaces/{workspace_id}", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_UPDATE_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def patch_workspace(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), workspace_repo: WorkspaceRepository = Depends(Database().get_repository(WorkspaceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: +async def patch_workspace(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), workspace_repo: WorkspaceRepository = Depends(WorkspaceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: try: is_disablement = resource_patch.isEnabled is not None and not resource_patch.isEnabled if is_disablement: @@ -153,7 +152,7 @@ async def patch_workspace(resource_patch: ResourcePatch, response: Response, use @workspaces_core_router.delete("/workspaces/{workspace_id}", response_model=OperationInResponse, name=strings.API_DELETE_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def delete_workspace(response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(Database().get_repository(OperationRepository)), workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def delete_workspace(response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(OperationRepository.get_repository()), workspace_repo=Depends(WorkspaceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: if await delete_validation(workspace, workspace_repo): operation = await send_uninstall_message( resource=workspace, @@ -171,7 +170,7 @@ async def delete_workspace(response: Response, user=Depends(get_current_admin_us @workspaces_core_router.post("/workspaces/{workspace_id}/invoke-action", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_INVOKE_ACTION_ON_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def invoke_action_on_workspace(response: Response, action: str, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def invoke_action_on_workspace(response: Response, action: str, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), workspace_repo=Depends(WorkspaceRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: operation = await send_custom_action_message( resource=workspace, resource_repo=workspace_repo, @@ -192,7 +191,7 @@ async def invoke_action_on_workspace(response: Response, action: str, user=Depen @workspaces_shared_router.get("/workspaces/{workspace_id}/workspace-service-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_WORKSPACE_SERVICE_TEMPLATES_IN_WORKSPACE) async def get_workspace_service_templates( workspace=Depends(get_workspace_by_id_from_path), - template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), + template_repo=Depends(ResourceTemplateRepository.get_repository()), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager_or_tre_admin)) -> ResourceTemplateInformationInList: template_infos = await template_repo.get_templates_information(ResourceType.WorkspaceService, user.roles) return ResourceTemplateInformationInList(templates=template_infos) @@ -203,14 +202,14 @@ async def get_workspace_service_templates( async def get_user_resource_templates( service_template_name: str, workspace=Depends(get_workspace_by_id_from_path), - template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), + template_repo=Depends(ResourceTemplateRepository.get_repository()), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager_or_tre_admin)) -> ResourceTemplateInformationInList: template_infos = await template_repo.get_templates_information(ResourceType.UserResource, user.roles, service_template_name) return ResourceTemplateInformationInList(templates=template_infos) @workspaces_shared_router.get("/workspaces/{workspace_id}/operations", response_model=OperationInList, name=strings.API_GET_RESOURCE_OPERATIONS, dependencies=[Depends(get_current_workspace_owner_or_tre_admin)]) -async def retrieve_workspace_operations_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(Database().get_repository(OperationRepository))) -> OperationInList: +async def retrieve_workspace_operations_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(OperationRepository.get_repository())) -> OperationInList: return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=workspace.id)) @@ -220,26 +219,26 @@ async def retrieve_workspace_operation_by_workspace_id_and_operation_id(workspac @workspaces_shared_router.get("/workspaces/{workspace_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_current_workspace_owner_or_tre_admin)]) -async def retrieve_workspace_history_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: +async def retrieve_workspace_history_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> ResourceHistoryInList: return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=workspace.id)) # WORKSPACE SERVICES ROUTES @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services", response_model=WorkspaceServicesInList, name=strings.API_GET_ALL_WORKSPACE_SERVICES, dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)]) -async def retrieve_users_active_workspace_services(workspace=Depends(get_workspace_by_id_from_path), workspace_services_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspaceServicesInList: +async def retrieve_users_active_workspace_services(workspace=Depends(get_workspace_by_id_from_path), workspace_services_repo=Depends(WorkspaceServiceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspaceServicesInList: workspace_services = await workspace_services_repo.get_active_workspace_services_for_workspace(workspace.id) await asyncio.gather(*[enrich_resource_with_available_upgrades(workspace_service, resource_template_repo) for workspace_service in workspace_services]) return WorkspaceServicesInList(workspaceServices=workspace_services) @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}", response_model=WorkspaceServiceInResponse, name=strings.API_GET_WORKSPACE_SERVICE_BY_ID, dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) -async def retrieve_workspace_service_by_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> WorkspaceServiceInResponse: +async def retrieve_workspace_service_by_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspaceServiceInResponse: await enrich_resource_with_available_upgrades(workspace_service, resource_template_repo) return WorkspaceServiceInResponse(workspaceService=workspace_service) @workspace_services_workspace_router.post("/workspaces/{workspace_id}/workspace-services", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_CREATE_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_user)]) -async def create_workspace_service(response: Response, workspace_service_input: WorkspaceServiceInCreate, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), workspace_repo=Depends(Database().get_repository(WorkspaceRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), workspace=Depends(get_deployed_workspace_by_id_from_path)) -> OperationInResponse: +async def create_workspace_service(response: Response, workspace_service_input: WorkspaceServiceInCreate, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), workspace_repo=Depends(WorkspaceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), workspace=Depends(get_deployed_workspace_by_id_from_path)) -> OperationInResponse: try: workspace_service, resource_template = await workspace_service_repo.create_workspace_service_item(workspace_service_input, workspace.id, user.roles) @@ -280,7 +279,7 @@ async def create_workspace_service(response: Response, workspace_service_input: @workspace_services_workspace_router.patch("/workspaces/{workspace_id}/workspace-services/{service_id}", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_UPDATE_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_or_researcher_user), Depends(get_workspace_by_id_from_path)]) -async def patch_workspace_service(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: +async def patch_workspace_service(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: try: is_disablement = resource_patch.isEnabled is not None and not resource_patch.isEnabled if is_disablement: @@ -306,7 +305,7 @@ async def patch_workspace_service(resource_patch: ResourcePatch, response: Respo @workspace_services_workspace_router.delete("/workspaces/{workspace_id}/workspace-services/{service_id}", response_model=OperationInResponse, name=strings.API_DELETE_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_user)]) -async def delete_workspace_service(response: Response, user=Depends(get_current_workspace_owner_user), workspace=Depends(get_workspace_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def delete_workspace_service(response: Response, user=Depends(get_current_workspace_owner_user), workspace=Depends(get_workspace_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), user_resource_repo=Depends(UserResourceRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: if await delete_validation(workspace_service, workspace_service_repo): operation = await send_uninstall_message( resource=workspace_service, @@ -324,7 +323,7 @@ async def delete_workspace_service(response: Response, user=Depends(get_current_ @workspace_services_workspace_router.post("/workspaces/{workspace_id}/workspace-services/{service_id}/invoke-action", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_INVOKE_ACTION_ON_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_user)]) -async def invoke_action_on_workspace_service(response: Response, action: str, user=Depends(get_current_workspace_owner_user), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), operations_repo=Depends(Database().get_repository(OperationRepository)), workspace_service_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: +async def invoke_action_on_workspace_service(response: Response, action: str, user=Depends(get_current_workspace_owner_user), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: operation = await send_custom_action_message( resource=workspace_service, resource_repo=workspace_service_repo, @@ -342,7 +341,7 @@ async def invoke_action_on_workspace_service(response: Response, action: str, us # workspace service operations @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/operations", response_model=OperationInList, name=strings.API_GET_RESOURCE_OPERATIONS, dependencies=[Depends(get_current_workspace_owner_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) -async def retrieve_workspace_service_operations_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), operations_repo=Depends(Database().get_repository(OperationRepository))) -> OperationInList: +async def retrieve_workspace_service_operations_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), operations_repo=Depends(OperationRepository.get_repository())) -> OperationInList: return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=workspace_service.id)) @@ -352,7 +351,7 @@ async def retrieve_workspace_service_operation_by_workspace_service_id_and_opera @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_current_workspace_owner_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) -async def retrieve_workspace_service_history_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: +async def retrieve_workspace_service_history_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> ResourceHistoryInList: return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=workspace_service.id)) @@ -362,8 +361,8 @@ async def retrieve_user_resources_for_workspace_service( workspace_id: str, service_id: str, user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), - resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), - user_resource_repo=Depends(Database().get_repository(UserResourceRepository))) -> UserResourcesInList: + resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), + user_resource_repo=Depends(UserResourceRepository.get_repository())) -> UserResourcesInList: user_resources = await user_resource_repo.get_user_resources_for_workspace_service(workspace_id, service_id) # filter only to the user - for researchers @@ -382,7 +381,7 @@ async def retrieve_user_resources_for_workspace_service( @user_resources_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/user-resources/{resource_id}", response_model=UserResourceInResponse, name=strings.API_GET_USER_RESOURCE, dependencies=[Depends(get_workspace_by_id_from_path)]) async def retrieve_user_resource_by_id( user_resource=Depends(get_user_resource_by_id_from_path), - resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), + resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)) -> UserResourceInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) @@ -397,10 +396,10 @@ async def retrieve_user_resource_by_id( async def create_user_resource( response: Response, user_resource_create: UserResourceInCreate, - user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), - resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), - operations_repo=Depends(Database().get_repository(OperationRepository)), - resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), + user_resource_repo=Depends(UserResourceRepository.get_repository()), + resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), + operations_repo=Depends(OperationRepository.get_repository()), + resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), workspace=Depends(get_deployed_workspace_by_id_from_path), workspace_service=Depends(get_deployed_workspace_service_by_id_from_path)) -> OperationInResponse: @@ -433,10 +432,10 @@ async def delete_user_resource( user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), user_resource=Depends(get_user_resource_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), - user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), - operations_repo=Depends(Database().get_repository(OperationRepository)), - resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), - resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> OperationInResponse: + user_resource_repo=Depends(UserResourceRepository.get_repository()), + operations_repo=Depends(OperationRepository.get_repository()), + resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), + resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) if user_resource.isEnabled: @@ -463,10 +462,10 @@ async def patch_user_resource( user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), user_resource=Depends(get_user_resource_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), - user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), - resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), - resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), - operations_repo=Depends(Database().get_repository(OperationRepository)), + user_resource_repo=Depends(UserResourceRepository.get_repository()), + resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), + resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), + operations_repo=Depends(OperationRepository.get_repository()), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) @@ -490,10 +489,10 @@ async def invoke_action_on_user_resource( action: str, user_resource=Depends(get_user_resource_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), - resource_template_repo=Depends(Database().get_repository(ResourceTemplateRepository)), - user_resource_repo=Depends(Database().get_repository(UserResourceRepository)), - operations_repo=Depends(Database().get_repository(OperationRepository)), - resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository)), + resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), + user_resource_repo=Depends(UserResourceRepository.get_repository()), + operations_repo=Depends(OperationRepository.get_repository()), + resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)) -> OperationInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) operation = await send_custom_action_message( @@ -517,7 +516,7 @@ async def invoke_action_on_user_resource( async def retrieve_user_resource_operations_by_user_resource_id( user_resource=Depends(get_user_resource_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), - operations_repo=Depends(Database().get_repository(OperationRepository))) -> OperationInList: + operations_repo=Depends(OperationRepository.get_repository())) -> OperationInList: validate_user_has_valid_role_for_user_resource(user, user_resource) return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=user_resource.id)) @@ -532,6 +531,6 @@ async def retrieve_user_resource_operations_by_user_resource_id_and_operation_id @user_resources_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/user-resources/{resource_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_workspace_by_id_from_path)]) -async def retrieve_user_resource_history_by_user_resource_id(user_resource=Depends(get_user_resource_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), resource_history_repo=Depends(Database().get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: +async def retrieve_user_resource_history_by_user_resource_id(user_resource=Depends(get_user_resource_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> ResourceHistoryInList: validate_user_has_valid_role_for_user_resource(user, user_resource) return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=user_resource.id)) diff --git a/api_app/db/migrations/airlock.py b/api_app/db/migrations/airlock.py index 9c834aba07..aa0ce011cd 100644 --- a/api_app/db/migrations/airlock.py +++ b/api_app/db/migrations/airlock.py @@ -40,15 +40,14 @@ async def change_review_resources_to_dict(self) -> int: num_updated = 0 for request in await self.query('SELECT * FROM c'): # Only migrate if airlockReviewResources property present and is a list - if 'reviewUserResources' in request: - if type(request['reviewUserResources']) == list: - updated_review_resources = {} - for i, resource in enumerate(request['reviewUserResources']): - updated_review_resources['UNKNOWN' + str(i)] = resource - - request['reviewUserResources'] = updated_review_resources - await self.update_item_dict(request) - num_updated += 1 + if 'reviewUserResources' in request and isinstance(request['reviewUserResources'], list): + updated_review_resources = {} + for i, resource in enumerate(request['reviewUserResources']): + updated_review_resources['UNKNOWN' + str(i)] = resource + + request['reviewUserResources'] = updated_review_resources + await self.update_item_dict(request) + num_updated += 1 return num_updated diff --git a/api_app/db/repositories/base.py b/api_app/db/repositories/base.py index cc1bc78385..9ac8af6eb0 100644 --- a/api_app/db/repositories/base.py +++ b/api_app/db/repositories/base.py @@ -1,10 +1,13 @@ -from typing import Optional +from typing import Callable, Optional from azure.cosmos.aio import ContainerProxy from azure.core import MatchConditions +from fastapi import HTTPException, status from pydantic import BaseModel -from core import config +from api.dependencies.database import Database +from resources.strings import STATE_STORE_ENDPOINT_NOT_RESPONDING from db.errors import UnableToAccessDatabase +from services.logging import logger class BaseRepository: @@ -13,6 +16,20 @@ async def create(cls, container_name: Optional[str] = None): cls._container: ContainerProxy = await cls._get_container(container_name) return cls + @classmethod + def get_repository(cls) -> Callable: + async def _get_repo() -> BaseRepository: + try: + return await cls.create() + except UnableToAccessDatabase: + logger.exception(STATE_STORE_ENDPOINT_NOT_RESPONDING) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=STATE_STORE_ENDPOINT_NOT_RESPONDING, + ) + + return _get_repo + @property def container(self) -> ContainerProxy: return self._container @@ -20,8 +37,8 @@ def container(self) -> ContainerProxy: @classmethod async def _get_container(cls, container_name) -> ContainerProxy: try: - database = cls._client.get_database_client(config.STATE_STORE_DATABASE) - container = database.get_container_client(container=container_name) + database = await Database().get_db_client() + container = await database.create_container_if_not_exists(id=container_name) return container except Exception: raise UnableToAccessDatabase diff --git a/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py b/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py index 34bf8659a1..b7f61a6345 100644 --- a/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py +++ b/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py @@ -30,19 +30,19 @@ @pytest_asyncio.fixture async def resource_repo() -> ResourceRepository: with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): - resource_repo_mock = await ResourceRepository.create() + resource_repo_mock = await ResourceRepository().create() yield resource_repo_mock @pytest_asyncio.fixture async def operations_repo() -> OperationRepository: - operation_repo_mock = await OperationRepository.create() + operation_repo_mock = await OperationRepository().create() yield operation_repo_mock @pytest_asyncio.fixture async def resource_history_repo() -> ResourceHistoryRepository: - resource_history_repo_mock = await ResourceHistoryRepository.create() + resource_history_repo_mock = await ResourceHistoryRepository().create() yield resource_history_repo_mock diff --git a/api_app/tests_ma/test_api/test_routes/test_workspaces.py b/api_app/tests_ma/test_api/test_routes/test_workspaces.py index 9279f2e26e..f948a13195 100644 --- a/api_app/tests_ma/test_api/test_routes/test_workspaces.py +++ b/api_app/tests_ma/test_api/test_routes/test_workspaces.py @@ -676,7 +676,7 @@ async def test_delete_workspace_returns_400_if_associated_workspace_services_are # [DELETE] /workspaces/{workspace_id} @ patch("api.routes.resource_helpers.ResourceRepository.get_resource_dependency_list") @ patch("api.routes.workspaces.ResourceTemplateRepository.get_template_by_name_and_version") - @ patch("api.dependencies.database.Database.get_repository") + @ patch("db.repositories.base.BaseRepository.get_repository") @ patch("api.dependencies.workspaces.WorkspaceRepository.get_workspace_by_id") @ patch("api.routes.workspaces.WorkspaceServiceRepository.get_active_workspace_services_for_workspace", return_value=[]) @ patch('api.routes.resource_helpers.send_resource_request_message', return_value=sample_resource_operation(resource_id=WORKSPACE_ID, operation_id=OPERATION_ID)) @@ -691,7 +691,7 @@ async def test_delete_workspace_sends_a_request_message_to_uninstall_the_workspa # [DELETE] /workspaces/{workspace_id} @ patch("api.routes.resource_helpers.ResourceRepository.get_resource_dependency_list") @ patch("api.routes.workspaces.ResourceTemplateRepository.get_template_by_name_and_version") - @ patch("api.dependencies.database.Database.get_repository") + @ patch("db.repositories.base.BaseRepository.get_repository") @ patch("api.dependencies.workspaces.WorkspaceRepository.get_workspace_by_id") @ patch("api.routes.workspaces.WorkspaceServiceRepository.get_active_workspace_services_for_workspace") async def test_delete_workspace_raises_503_if_marking_the_resource_as_deleted_in_the_db_fails(self, ___, get_workspace_mock, _____, resource_template_repo, ______, client, app, disabled_workspace, basic_resource_template): diff --git a/api_app/tests_ma/test_db/test_repositories/test_base_repository.py b/api_app/tests_ma/test_db/test_repositories/test_base_repository.py index 4d8c859edb..09c7cbf0bf 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_base_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_base_repository.py @@ -10,6 +10,6 @@ async def test_instantiating_a_repo_raises_unable_to_access_database_if_database_cant_be_accessed(): with patch("azure.cosmos.CosmosClient") as cosmos_client_mock: - cosmos_client_mock.get_database_client = MagicMock(side_effect=Exception) + cosmos_client_mock.create_container_if_not_exists = MagicMock(side_effect=Exception) with pytest.raises(UnableToAccessDatabase): await BaseRepository.create("container") diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py index b4077e6a0b..4439911752 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py @@ -17,7 +17,7 @@ @pytest_asyncio.fixture async def resource_history_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - resource_history_repo = await ResourceHistoryRepository.create() + resource_history_repo = await ResourceHistoryRepository().create() yield resource_history_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py index 62f9282112..c929469017 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py @@ -27,14 +27,14 @@ @pytest_asyncio.fixture async def resource_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - resource_repo = await ResourceRepository.create() + resource_repo = await ResourceRepository().create() yield resource_repo @pytest_asyncio.fixture async def resource_history_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - resource_history_repo = await ResourceHistoryRepository.create() + resource_history_repo = await ResourceHistoryRepository().create() yield resource_history_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py index 831f1fabb7..20fa5f30fe 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py @@ -15,7 +15,7 @@ @pytest_asyncio.fixture async def resource_template_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - resource_template_repo = await ResourceTemplateRepository.create() + resource_template_repo = await ResourceTemplateRepository().create() yield resource_template_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py b/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py index f82c9ea710..95014a53f7 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py @@ -18,14 +18,14 @@ @pytest_asyncio.fixture async def shared_service_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): - shared_service_repo = await SharedServiceRepository.create() + shared_service_repo = await SharedServiceRepository().create() yield shared_service_repo @pytest_asyncio.fixture async def operations_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - operations_repo = await OperationRepository.create() + operations_repo = await OperationRepository().create() yield operations_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py index 8025936185..e9d095bf01 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py @@ -12,7 +12,7 @@ @pytest_asyncio.fixture async def resource_template_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - resource_template_repo = await ResourceTemplateRepository.create() + resource_template_repo = await ResourceTemplateRepository().create() yield resource_template_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py b/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py index e6854c8b62..034b8155b3 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py @@ -27,7 +27,7 @@ def basic_user_resource_request(): @pytest_asyncio.fixture async def user_resource_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - user_resource_repo = await UserResourceRepository.create() + user_resource_repo = await UserResourceRepository().create() yield user_resource_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py index fc1ed642a5..01ec7a8116 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py @@ -14,7 +14,7 @@ @pytest_asyncio.fixture async def resource_template_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - resource_template_repo = await ResourceTemplateRepository.create() + resource_template_repo = await ResourceTemplateRepository().create() yield resource_template_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py b/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py index 6aa551aaa8..5ad523ccb5 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py @@ -21,14 +21,14 @@ def basic_workspace_request(): @pytest_asyncio.fixture async def workspace_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - workspace_repo = await WorkspaceRepository.create() + workspace_repo = await WorkspaceRepository().create() yield workspace_repo @pytest_asyncio.fixture async def operations_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - operations_repo = await OperationRepository.create() + operations_repo = await OperationRepository().create() yield operations_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py b/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py index 72a55ae9ba..80909411f5 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py @@ -19,14 +19,14 @@ @pytest_asyncio.fixture async def workspace_service_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - workspace_repo = await WorkspaceServiceRepository.create() + workspace_repo = await WorkspaceServiceRepository().create() yield workspace_repo @pytest_asyncio.fixture async def operations_repo(): with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): - operations_repo = await OperationRepository.create() + operations_repo = await OperationRepository().create() yield operations_repo diff --git a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py index b013914ba2..aa69cddd12 100644 --- a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py @@ -134,9 +134,9 @@ async def test_receiving_bad_json_logs_error(cosmos_client, logging_mock, payloa @patch('service_bus.deployment_status_updater.ResourceHistoryRepository.create') -@patch('service_bus.deployment_status_updater.ResourceTemplateRepository.create') -@patch('service_bus.deployment_status_updater.OperationRepository.create') -@patch('service_bus.deployment_status_updater.ResourceRepository.create') +@patch('service_bus.deployment_status_updater.ResourceTemplateRepository') +@patch('service_bus.deployment_status_updater.OperationRepository') +@patch('service_bus.deployment_status_updater.ResourceRepository') @patch('services.logging.logger.exception') @patch("api.dependencies.database.Database.get_db_client") async def test_receiving_good_message(cosmos_client, logging_mock, resource_repo, operation_repo, _, __): @@ -296,7 +296,7 @@ async def test_properties_dont_change_with_no_outputs(cosmos_client, resource_re @patch('service_bus.deployment_status_updater.ResourceHistoryRepository.create') -@patch('service_bus.deployment_status_updater.ResourceTemplateRepository.create') +@patch('service_bus.deployment_status_updater.ResourceTemplateRepository'.create) @patch('service_bus.deployment_status_updater.update_resource_for_step') @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') diff --git a/api_app/tests_ma/test_service_bus/test_resource_request_sender.py b/api_app/tests_ma/test_service_bus/test_resource_request_sender.py index 71e3785191..e90894791d 100644 --- a/api_app/tests_ma/test_service_bus/test_resource_request_sender.py +++ b/api_app/tests_ma/test_service_bus/test_resource_request_sender.py @@ -40,11 +40,11 @@ def create_test_resource(): @pytest.mark.parametrize( "request_action", [RequestAction.Install, RequestAction.UnInstall] ) -@patch("service_bus.resource_request_sender.ResourceHistoryRepository.create") -@patch("service_bus.resource_request_sender.OperationRepository.create") +@patch("service_bus.resource_request_sender.ResourceHistoryRepository") +@patch("service_bus.resource_request_sender.OperationRepository") @patch("service_bus.helpers.ServiceBusClient") -@patch("service_bus.resource_request_sender.ResourceRepository.create") -@patch("service_bus.resource_request_sender.ResourceTemplateRepository.create") +@patch("service_bus.resource_request_sender.ResourceRepository") +@patch("service_bus.resource_request_sender.ResourceTemplateRepository") async def test_resource_request_message_generated_correctly( resource_template_repo, resource_repo, @@ -84,10 +84,10 @@ async def test_resource_request_message_generated_correctly( assert sent_message_as_json["action"] == request_action -@patch("service_bus.resource_request_sender.ResourceHistoryRepository.create") -@patch("service_bus.resource_request_sender.OperationRepository.create") -@patch("service_bus.resource_request_sender.ResourceRepository.create") -@patch("service_bus.resource_request_sender.ResourceTemplateRepository.create") +@patch("service_bus.resource_request_sender.ResourceHistoryRepository") +@patch("service_bus.resource_request_sender.OperationRepository") +@patch("service_bus.resource_request_sender.ResourceRepository") +@patch("service_bus.resource_request_sender.ResourceTemplateRepository") async def test_multi_step_document_sends_first_step( resource_template_repo, resource_repo, @@ -146,9 +146,9 @@ async def test_multi_step_document_sends_first_step( ) -@patch("service_bus.resource_request_sender.ResourceHistoryRepository.create") -@patch("service_bus.resource_request_sender.ResourceRepository.create") -@patch("service_bus.resource_request_sender.ResourceTemplateRepository.create") +@patch("service_bus.resource_request_sender.ResourceHistoryRepository") +@patch("service_bus.resource_request_sender.ResourceRepository") +@patch("service_bus.resource_request_sender.ResourceTemplateRepository") async def test_multi_step_document_retries( resource_template_repo, resource_repo, diff --git a/api_app/tests_ma/test_services/test_airlock.py b/api_app/tests_ma/test_services/test_airlock.py index 49eabeef1d..c7fccec4d9 100644 --- a/api_app/tests_ma/test_services/test_airlock.py +++ b/api_app/tests_ma/test_services/test_airlock.py @@ -31,7 +31,7 @@ @pytest_asyncio.fixture async def airlock_request_repo_mock(no_database): _ = no_database - airlock_request_repo_mock = await AirlockRequestRepository.create() + airlock_request_repo_mock = await AirlockRequestRepository().create() yield airlock_request_repo_mock From 96c296cd9701393ca8d786bfff78e4fed731f5b5 Mon Sep 17 00:00:00 2001 From: marrobi Date: Tue, 2 Jan 2024 15:36:04 +0000 Subject: [PATCH 22/32] reduce changes --- api_app/api/helpers.py | 21 +++++ api_app/api/routes/airlock.py | 45 +++++------ api_app/api/routes/api.py | 5 +- api_app/api/routes/costs.py | 11 +-- api_app/api/routes/migrations.py | 15 ++-- api_app/api/routes/operations.py | 3 +- .../api/routes/shared_service_templates.py | 7 +- api_app/api/routes/shared_services.py | 17 ++-- api_app/api/routes/user_resource_templates.py | 7 +- .../api/routes/workspace_service_templates.py | 7 +- api_app/api/routes/workspace_templates.py | 7 +- api_app/api/routes/workspaces.py | 79 ++++++++++--------- 12 files changed, 128 insertions(+), 96 deletions(-) create mode 100644 api_app/api/helpers.py diff --git a/api_app/api/helpers.py b/api_app/api/helpers.py new file mode 100644 index 0000000000..8cd137d59a --- /dev/null +++ b/api_app/api/helpers.py @@ -0,0 +1,21 @@ +from typing import Callable, Type + +from fastapi import HTTPException, logger, status + +from db.errors import UnableToAccessDatabase +from db.repositories.base import BaseRepository +from resources.strings import STATE_STORE_ENDPOINT_NOT_RESPONDING + + +def get_repository(repo_type: Type[BaseRepository],) -> Callable: + async def _get_repo() -> BaseRepository: + try: + return await repo_type.create() + except UnableToAccessDatabase: + logger.exception(STATE_STORE_ENDPOINT_NOT_RESPONDING) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=STATE_STORE_ENDPOINT_NOT_RESPONDING, + ) + + return _get_repo diff --git a/api_app/api/routes/airlock.py b/api_app/api/routes/airlock.py index d7aa42b378..7aefa62b41 100644 --- a/api_app/api/routes/airlock.py +++ b/api_app/api/routes/airlock.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, status as status_code, Response from jsonschema.exceptions import ValidationError +from api.helpers import get_repository from db.repositories.resources_history import ResourceHistoryRepository from db.repositories.user_resources import UserResourceRepository from db.repositories.workspace_services import WorkspaceServiceRepository @@ -35,7 +36,7 @@ response_model=AirlockRequestWithAllowedUserActions, name=strings.API_CREATE_AIRLOCK_REQUEST, dependencies=[Depends(get_current_workspace_owner_or_researcher_user), Depends(get_workspace_by_id_from_path)]) async def create_draft_request(airlock_request_input: AirlockRequestInCreate, user=Depends(get_current_workspace_owner_or_researcher_user), - airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), + airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), workspace=Depends(get_deployed_workspace_by_id_from_path)) -> AirlockRequestWithAllowedUserActions: if workspace.properties.get("enable_airlock") is False: raise HTTPException(status_code=status_code.HTTP_405_METHOD_NOT_ALLOWED, detail=strings.AIRLOCK_NOT_ENABLED_IN_WORKSPACE) @@ -56,7 +57,7 @@ async def create_draft_request(airlock_request_input: AirlockRequestInCreate, us dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) async def get_all_airlock_requests_by_workspace( - airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), + airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), workspace=Depends(get_deployed_workspace_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), creator_user_id: Optional[str] = None, type: Optional[AirlockRequestType] = None, status: Optional[AirlockRequestStatus] = None, @@ -76,7 +77,7 @@ async def get_all_airlock_requests_by_workspace( response_model=AirlockRequestWithAllowedUserActions, name=strings.API_GET_AIRLOCK_REQUEST, dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) async def retrieve_airlock_request_by_id(airlock_request=Depends(get_airlock_request_by_id_from_path), - airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), + airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)) -> AirlockRequestWithAllowedUserActions: allowed_actions = get_allowed_actions(airlock_request, user, airlock_request_repo) return AirlockRequestWithAllowedUserActions(airlockRequest=airlock_request, allowedUserActions=allowed_actions) @@ -87,7 +88,7 @@ async def retrieve_airlock_request_by_id(airlock_request=Depends(get_airlock_req dependencies=[Depends(get_current_workspace_owner_or_researcher_user), Depends(get_workspace_by_id_from_path)]) async def create_submit_request(airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user), - airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), + airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), workspace=Depends(get_workspace_by_id_from_path)) -> AirlockRequestWithAllowedUserActions: updated_request = await update_and_publish_event_airlock_request(airlock_request, airlock_request_repo, user, workspace, new_status=AirlockRequestStatus.Submitted) @@ -101,12 +102,12 @@ async def create_submit_request(airlock_request=Depends(get_airlock_request_by_i async def create_cancel_request(airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user), workspace=Depends(get_workspace_by_id_from_path), - airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), - user_resource_repo=Depends(UserResourceRepository.get_repository()), - workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), - resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), - operation_repo=Depends(OperationRepository.get_repository()), - resource_template_repo=Depends(ResourceTemplateRepository.get_repository()),) -> AirlockRequestWithAllowedUserActions: + airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), + user_resource_repo=Depends(get_repository(UserResourceRepository)), + workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), + resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), + operation_repo=Depends(get_repository(OperationRepository)), + resource_template_repo=Depends(get_repository(ResourceTemplateRepository)),) -> AirlockRequestWithAllowedUserActions: updated_request = await cancel_request(airlock_request, user, workspace, airlock_request_repo, user_resource_repo, workspace_service_repo, resource_template_repo, operation_repo, resource_history_repo) allowed_actions = get_allowed_actions(updated_request, user, airlock_request_repo) return AirlockRequestWithAllowedUserActions(airlockRequest=updated_request, allowedUserActions=allowed_actions) @@ -121,12 +122,12 @@ async def create_review_user_resource( airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_airlock_manager_user), workspace=Depends(get_deployed_workspace_by_id_from_path), - user_resource_repo=Depends(UserResourceRepository.get_repository()), - workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), - operation_repo=Depends(OperationRepository.get_repository()), - airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), - resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), - resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> AirlockRequestAndOperationInResponse: + user_resource_repo=Depends(get_repository(UserResourceRepository)), + workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), + operation_repo=Depends(get_repository(OperationRepository)), + airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), + resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), + resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> AirlockRequestAndOperationInResponse: if airlock_request.status != AirlockRequestStatus.InReview: raise HTTPException(status_code=status_code.HTTP_400_BAD_REQUEST, @@ -159,12 +160,12 @@ async def create_airlock_review( airlock_request=Depends(get_airlock_request_by_id_from_path), user=Depends(get_current_airlock_manager_user), workspace=Depends(get_deployed_workspace_by_id_from_path), - airlock_request_repo=Depends(AirlockRequestRepository.get_repository()), - user_resource_repo=Depends(UserResourceRepository.get_repository()), - workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), - operation_repo=Depends(OperationRepository.get_repository()), - resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), - resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> AirlockRequestWithAllowedUserActions: + airlock_request_repo=Depends(get_repository(AirlockRequestRepository)), + user_resource_repo=Depends(get_repository(UserResourceRepository)), + workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), + operation_repo=Depends(get_repository(OperationRepository)), + resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), + resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> AirlockRequestWithAllowedUserActions: try: updated_airlock_request = await review_airlock_request(airlock_review_input, airlock_request, user, workspace, airlock_request_repo, user_resource_repo, workspace_service_repo, operation_repo, resource_template_repo, resource_history_repo) allowed_actions = get_allowed_actions(updated_airlock_request, user, airlock_request_repo) diff --git a/api_app/api/routes/api.py b/api_app/api/routes/api.py index 99bf4b9ee9..3c27782f72 100644 --- a/api_app/api/routes/api.py +++ b/api_app/api/routes/api.py @@ -5,6 +5,7 @@ from fastapi.openapi.docs import get_swagger_ui_html, get_swagger_ui_oauth2_redirect_html from fastapi.openapi.utils import get_openapi +from api.helpers import get_repository from db.repositories.workspaces import WorkspaceRepository from api.routes import health, ping, workspaces, workspace_templates, workspace_service_templates, user_resource_templates, \ shared_services, shared_service_templates, migrations, costs, airlock, operations, metadata @@ -115,7 +116,7 @@ def get_scope(workspace) -> str: @workspace_swagger_router.get("/workspaces/{workspace_id}/openapi.json", include_in_schema=False, name="openapi_definitions") -async def get_openapi_json(workspace_id: str, request: Request, workspace_repo=Depends(WorkspaceRepository.get_repository())): +async def get_openapi_json(workspace_id: str, request: Request, workspace_repo=Depends(get_repository(WorkspaceRepository))): global openapi_definitions if openapi_definitions[workspace_id] is None: @@ -145,7 +146,7 @@ async def get_openapi_json(workspace_id: str, request: Request, workspace_repo=D @workspace_swagger_router.get("/workspaces/{workspace_id}/docs", include_in_schema=False, name="workspace_swagger") -async def get_workspace_swagger(workspace_id, request: Request, workspace_repo=Depends(WorkspaceRepository.get_repository())): +async def get_workspace_swagger(workspace_id, request: Request, workspace_repo=Depends(get_repository(WorkspaceRepository))): workspace = await workspace_repo.get_workspace_by_id(workspace_id) scope = get_scope(workspace) diff --git a/api_app/api/routes/costs.py b/api_app/api/routes/costs.py index 57e1352451..f7e89fa1fc 100644 --- a/api_app/api/routes/costs.py +++ b/api_app/api/routes/costs.py @@ -8,6 +8,7 @@ from models.schemas.costs import get_cost_report_responses, get_workspace_cost_report_responses from core import config +from api.helpers import get_repository from db.repositories.shared_services import SharedServiceRepository from db.repositories.user_resources import UserResourceRepository from db.repositories.workspace_services import WorkspaceServiceRepository @@ -55,8 +56,8 @@ def __init__( async def costs( params: CostsQueryParams = Depends(), cost_service: CostService = Depends(cost_service_factory), - workspace_repo=Depends(WorkspaceRepository.get_repository()), - shared_services_repo=Depends(SharedServiceRepository.get_repository())) -> CostReport: + workspace_repo=Depends(get_repository(WorkspaceRepository)), + shared_services_repo=Depends(get_repository(SharedServiceRepository))) -> CostReport: validate_report_period(params.from_date, params.to_date) try: @@ -89,9 +90,9 @@ async def costs( responses=get_workspace_cost_report_responses()) async def workspace_costs(workspace_id: UUID4, params: CostsQueryParams = Depends(), cost_service: CostService = Depends(cost_service_factory), - workspace_repo=Depends(WorkspaceRepository.get_repository()), - workspace_services_repo=Depends(WorkspaceServiceRepository.get_repository()), - user_resource_repo=Depends(UserResourceRepository.get_repository())) -> WorkspaceCostReport: + workspace_repo=Depends(get_repository(WorkspaceRepository)), + workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository)), + user_resource_repo=Depends(get_repository(UserResourceRepository))) -> WorkspaceCostReport: validate_report_period(params.from_date, params.to_date) try: diff --git a/api_app/api/routes/migrations.py b/api_app/api/routes/migrations.py index cbb420fd1a..692f664583 100644 --- a/api_app/api/routes/migrations.py +++ b/api_app/api/routes/migrations.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from db.migrations.airlock import AirlockMigration from db.migrations.resources import ResourceMigration +from api.helpers import get_repository from db.repositories.operations import OperationRepository from db.repositories.resources_history import ResourceHistoryRepository from services.authentication import get_current_admin_user @@ -19,13 +20,13 @@ name=strings.API_MIGRATE_DATABASE, response_model=MigrationOutList, dependencies=[Depends(get_current_admin_user)]) -async def migrate_database(resources_repo=Depends(ResourceRepository.get_repository()), - operations_repo=Depends(OperationRepository.get_repository()), - resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), - shared_services_migration=Depends(SharedServiceMigration().get_repository()), - workspace_migration=Depends(WorkspaceMigration().get_repository()), - resource_migration=Depends(ResourceMigration().get_repository()), - airlock_migration=Depends(AirlockMigration().get_repository())): +async def migrate_database(resources_repo=Depends(get_repository(ResourceRepository)), + operations_repo=Depends(get_repository(OperationRepository)), + resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), + shared_services_migration=Depends(get_repository(SharedServiceMigration)), + workspace_migration=Depends(get_repository(WorkspaceMigration)), + resource_migration=Depends(get_repository(ResourceMigration)), + airlock_migration=Depends(get_repository(AirlockMigration)),): try: migrations = list() logger.info("PR 1030") diff --git a/api_app/api/routes/operations.py b/api_app/api/routes/operations.py index 0f2d1cc955..0ab67f5be2 100644 --- a/api_app/api/routes/operations.py +++ b/api_app/api/routes/operations.py @@ -1,5 +1,6 @@ from fastapi import APIRouter, Depends +from api.helpers import get_repository from db.repositories.operations import OperationRepository from models.schemas.operation import OperationInList from resources import strings @@ -10,6 +11,6 @@ @operations_router.get("/operations", response_model=OperationInList, name=strings.API_GET_MY_OPERATIONS) -async def get_my_operations(user=Depends(get_current_tre_user_or_tre_admin), operations_repo=Depends(OperationRepository.get_repository())) -> OperationInList: +async def get_my_operations(user=Depends(get_current_tre_user_or_tre_admin), operations_repo=Depends(get_repository(OperationRepository))) -> OperationInList: operations = await operations_repo.get_my_operations(user_id=user.id) return OperationInList(operations=operations) diff --git a/api_app/api/routes/shared_service_templates.py b/api_app/api/routes/shared_service_templates.py index 57effd152c..b7801c3789 100644 --- a/api_app/api/routes/shared_service_templates.py +++ b/api_app/api/routes/shared_service_templates.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as +from api.helpers import get_repository from db.errors import EntityDoesNotExist, EntityVersionExist, InvalidInput from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType @@ -16,13 +17,13 @@ @shared_service_templates_core_router.get("/shared-service-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_SHARED_SERVICE_TEMPLATES) -async def get_shared_service_templates(authorized_only: bool = False, template_repo=Depends(ResourceTemplateRepository.get_repository()), user=Depends(get_current_tre_user_or_tre_admin)) -> ResourceTemplateInformationInList: +async def get_shared_service_templates(authorized_only: bool = False, template_repo=Depends(get_repository(ResourceTemplateRepository)), user=Depends(get_current_tre_user_or_tre_admin)) -> ResourceTemplateInformationInList: templates_infos = await template_repo.get_templates_information(ResourceType.SharedService, user.roles if authorized_only else None) return ResourceTemplateInformationInList(templates=templates_infos) @shared_service_templates_core_router.get("/shared-service-templates/{shared_service_template_name}", response_model=SharedServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_GET_SHARED_SERVICE_TEMPLATE_BY_NAME, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_shared_service_template(shared_service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> SharedServiceTemplateInResponse: +async def get_shared_service_template(shared_service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> SharedServiceTemplateInResponse: try: template = await get_template(shared_service_template_name, template_repo, ResourceType.SharedService, is_update=is_update, version=version) return parse_obj_as(SharedServiceTemplateInResponse, template) @@ -31,7 +32,7 @@ async def get_shared_service_template(shared_service_template_name: str, is_upda @shared_service_templates_core_router.post("/shared-service-templates", status_code=status.HTTP_201_CREATED, response_model=SharedServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_SHARED_SERVICE_TEMPLATES, dependencies=[Depends(get_current_admin_user)]) -async def register_shared_service_template(template_input: SharedServiceTemplateInCreate, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplateInResponse: +async def register_shared_service_template(template_input: SharedServiceTemplateInCreate, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.SharedService) except EntityVersionExist: diff --git a/api_app/api/routes/shared_services.py b/api_app/api/routes/shared_services.py index 2f7831b728..6e23945bdd 100644 --- a/api_app/api/routes/shared_services.py +++ b/api_app/api/routes/shared_services.py @@ -5,6 +5,7 @@ from db.repositories.operations import OperationRepository from db.errors import DuplicateEntity, MajorVersionUpdateDenied, UserNotAuthorizedToUseTemplate, TargetTemplateVersionDoesNotExist, VersionDowngradeDenied +from api.helpers import get_repository from api.dependencies.shared_services import get_shared_service_by_id_from_path, get_operation_by_id_from_path from db.repositories.resource_templates import ResourceTemplateRepository from db.repositories.resources_history import ResourceHistoryRepository @@ -32,7 +33,7 @@ def user_is_tre_admin(user): @shared_services_router.get("/shared-services", response_model=SharedServicesInList, name=strings.API_GET_ALL_SHARED_SERVICES, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def retrieve_shared_services(shared_services_repo=Depends(SharedServiceRepository.get_repository()), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())) -> SharedServicesInList: +async def retrieve_shared_services(shared_services_repo=Depends(get_repository(SharedServiceRepository)), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))) -> SharedServicesInList: shared_services = await shared_services_repo.get_active_shared_services() await asyncio.gather(*[enrich_resource_with_available_upgrades(shared_service, resource_template_repo) for shared_service in shared_services]) if user_is_tre_admin(user): @@ -42,7 +43,7 @@ async def retrieve_shared_services(shared_services_repo=Depends(SharedServiceRep @shared_services_router.get("/shared-services/{shared_service_id}", response_model=SharedServiceInResponse, name=strings.API_GET_SHARED_SERVICE_BY_ID, dependencies=[Depends(get_current_tre_user_or_tre_admin), Depends(get_shared_service_by_id_from_path)]) -async def retrieve_shared_service_by_id(shared_service=Depends(get_shared_service_by_id_from_path), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())): +async def retrieve_shared_service_by_id(shared_service=Depends(get_shared_service_by_id_from_path), user=Depends(get_current_tre_user_or_tre_admin), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))): await enrich_resource_with_available_upgrades(shared_service, resource_template_repo) if user_is_tre_admin(user): return SharedServiceInResponse(sharedService=shared_service) @@ -51,7 +52,7 @@ async def retrieve_shared_service_by_id(shared_service=Depends(get_shared_servic @shared_services_router.post("/shared-services", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_CREATE_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user)]) -async def create_shared_service(response: Response, shared_service_input: SharedServiceInCreate, user=Depends(get_current_admin_user), shared_services_repo=Depends(SharedServiceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: +async def create_shared_service(response: Response, shared_service_input: SharedServiceInCreate, user=Depends(get_current_admin_user), shared_services_repo=Depends(get_repository(SharedServiceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: try: shared_service, resource_template = await shared_services_repo.create_shared_service_item(shared_service_input, user.roles) except (ValidationError, ValueError) as e: @@ -82,7 +83,7 @@ async def create_shared_service(response: Response, shared_service_input: Shared response_model=OperationInResponse, name=strings.API_UPDATE_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user), Depends(get_shared_service_by_id_from_path)]) -async def patch_shared_service(shared_service_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), shared_service_repo=Depends(SharedServiceRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), etag: str = Header(...), force_version_update: bool = False) -> SharedServiceInResponse: +async def patch_shared_service(shared_service_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), shared_service_repo=Depends(get_repository(SharedServiceRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), etag: str = Header(...), force_version_update: bool = False) -> SharedServiceInResponse: try: patched_shared_service, _ = await shared_service_repo.patch_shared_service(shared_service, shared_service_patch, etag, resource_template_repo, resource_history_repo, user, force_version_update) operation = await send_resource_request_message( @@ -105,7 +106,7 @@ async def patch_shared_service(shared_service_patch: ResourcePatch, response: Re @shared_services_router.delete("/shared-services/{shared_service_id}", response_model=OperationInResponse, name=strings.API_DELETE_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user)]) -async def delete_shared_service(response: Response, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(OperationRepository.get_repository()), shared_service_repo=Depends(SharedServiceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: +async def delete_shared_service(response: Response, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(get_repository(OperationRepository)), shared_service_repo=Depends(get_repository(SharedServiceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: if shared_service.isEnabled: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=strings.SHARED_SERVICE_NEEDS_TO_BE_DISABLED_BEFORE_DELETION) @@ -124,7 +125,7 @@ async def delete_shared_service(response: Response, user=Depends(get_current_adm @shared_services_router.post("/shared-services/{shared_service_id}/invoke-action", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_INVOKE_ACTION_ON_SHARED_SERVICE, dependencies=[Depends(get_current_admin_user)]) -async def invoke_action_on_shared_service(response: Response, action: str, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), shared_service_repo=Depends(SharedServiceRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: +async def invoke_action_on_shared_service(response: Response, action: str, user=Depends(get_current_admin_user), shared_service=Depends(get_shared_service_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), shared_service_repo=Depends(get_repository(SharedServiceRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: operation = await send_custom_action_message( resource=shared_service, resource_repo=shared_service_repo, @@ -142,7 +143,7 @@ async def invoke_action_on_shared_service(response: Response, action: str, user= # Shared service operations @shared_services_router.get("/shared-services/{shared_service_id}/operations", response_model=OperationInList, name=strings.API_GET_RESOURCE_OPERATIONS, dependencies=[Depends(get_current_admin_user), Depends(get_shared_service_by_id_from_path)]) -async def retrieve_shared_service_operations_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(OperationRepository.get_repository())) -> OperationInList: +async def retrieve_shared_service_operations_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), operations_repo=Depends(get_repository(OperationRepository))) -> OperationInList: return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=shared_service.id)) @@ -153,5 +154,5 @@ async def retrieve_shared_service_operation_by_shared_service_id_and_operation_i # Shared service history @shared_services_router.get("/shared-services/{shared_service_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_current_admin_user)]) -async def retrieve_shared_service_history_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> ResourceHistoryInList: +async def retrieve_shared_service_history_by_shared_service_id(shared_service=Depends(get_shared_service_by_id_from_path), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=shared_service.id)) diff --git a/api_app/api/routes/user_resource_templates.py b/api_app/api/routes/user_resource_templates.py index 4330a3b688..f4cb24cb5f 100644 --- a/api_app/api/routes/user_resource_templates.py +++ b/api_app/api/routes/user_resource_templates.py @@ -6,6 +6,7 @@ from api.dependencies.workspace_service_templates import get_workspace_service_template_by_name_from_path from api.routes.resource_helpers import get_template from db.errors import EntityVersionExist, InvalidInput +from api.helpers import get_repository from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType from models.schemas.user_resource_template import UserResourceTemplateInResponse, UserResourceTemplateInCreate @@ -18,19 +19,19 @@ @user_resource_templates_core_router.get("/workspace-service-templates/{service_template_name}/user-resource-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_USER_RESOURCE_TEMPLATES, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_user_resource_templates_for_service_template(service_template_name: str, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplateInformationInList: +async def get_user_resource_templates_for_service_template(service_template_name: str, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplateInformationInList: template_infos = await template_repo.get_templates_information(ResourceType.UserResource, parent_service_name=service_template_name) return ResourceTemplateInformationInList(templates=template_infos) @user_resource_templates_core_router.get("/workspace-service-templates/{service_template_name}/user-resource-templates/{user_resource_template_name}", response_model=UserResourceTemplateInResponse, response_model_exclude_none=True, name=strings.API_GET_USER_RESOURCE_TEMPLATE_BY_NAME, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_user_resource_template(service_template_name: str, user_resource_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> UserResourceTemplateInResponse: +async def get_user_resource_template(service_template_name: str, user_resource_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> UserResourceTemplateInResponse: template = await get_template(user_resource_template_name, template_repo, ResourceType.UserResource, service_template_name, is_update=is_update, version=version) return parse_obj_as(UserResourceTemplateInResponse, template) @user_resource_templates_core_router.post("/workspace-service-templates/{service_template_name}/user-resource-templates", status_code=status.HTTP_201_CREATED, response_model=UserResourceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_USER_RESOURCE_TEMPLATES, dependencies=[Depends(get_current_admin_user)]) -async def register_user_resource_template(template_input: UserResourceTemplateInCreate, template_repo=Depends(ResourceTemplateRepository.get_repository()), workspace_service_template=Depends(get_workspace_service_template_by_name_from_path)) -> UserResourceTemplateInResponse: +async def register_user_resource_template(template_input: UserResourceTemplateInCreate, template_repo=Depends(get_repository(ResourceTemplateRepository)), workspace_service_template=Depends(get_workspace_service_template_by_name_from_path)) -> UserResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.UserResource, workspace_service_template.name) except EntityVersionExist: diff --git a/api_app/api/routes/workspace_service_templates.py b/api_app/api/routes/workspace_service_templates.py index 6411e15e1b..e6df3fadba 100644 --- a/api_app/api/routes/workspace_service_templates.py +++ b/api_app/api/routes/workspace_service_templates.py @@ -4,6 +4,7 @@ from api.routes.resource_helpers import get_template from db.errors import EntityVersionExist, InvalidInput +from api.helpers import get_repository from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType from models.schemas.resource_template import ResourceTemplateInResponse, ResourceTemplateInformationInList @@ -16,19 +17,19 @@ @workspace_service_templates_core_router.get("/workspace-service-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_WORKSPACE_SERVICE_TEMPLATES, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_workspace_service_templates(template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplateInformationInList: +async def get_workspace_service_templates(template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplateInformationInList: templates_infos = await template_repo.get_templates_information(ResourceType.WorkspaceService) return ResourceTemplateInformationInList(templates=templates_infos) @workspace_service_templates_core_router.get("/workspace-service-templates/{service_template_name}", response_model=WorkspaceServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_GET_WORKSPACE_SERVICE_TEMPLATE_BY_NAME, dependencies=[Depends(get_current_tre_user_or_tre_admin)]) -async def get_workspace_service_template(service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspaceServiceTemplateInResponse: +async def get_workspace_service_template(service_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspaceServiceTemplateInResponse: template = await get_template(service_template_name, template_repo, ResourceType.WorkspaceService, is_update=is_update, version=version) return parse_obj_as(WorkspaceServiceTemplateInResponse, template) @workspace_service_templates_core_router.post("/workspace-service-templates", status_code=status.HTTP_201_CREATED, response_model=WorkspaceServiceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_WORKSPACE_SERVICE_TEMPLATES, dependencies=[Depends(get_current_admin_user)]) -async def register_workspace_service_template(template_input: WorkspaceServiceTemplateInCreate, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplateInResponse: +async def register_workspace_service_template(template_input: WorkspaceServiceTemplateInCreate, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.WorkspaceService) except EntityVersionExist: diff --git a/api_app/api/routes/workspace_templates.py b/api_app/api/routes/workspace_templates.py index 8dd0efda4c..c61b6f7e82 100644 --- a/api_app/api/routes/workspace_templates.py +++ b/api_app/api/routes/workspace_templates.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import parse_obj_as +from api.helpers import get_repository from db.errors import EntityVersionExist, InvalidInput from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType @@ -16,19 +17,19 @@ @workspace_templates_admin_router.get("/workspace-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_WORKSPACE_TEMPLATES) -async def get_workspace_templates(authorized_only: bool = False, template_repo=Depends(ResourceTemplateRepository.get_repository()), user=Depends(get_current_admin_user)) -> ResourceTemplateInformationInList: +async def get_workspace_templates(authorized_only: bool = False, template_repo=Depends(get_repository(ResourceTemplateRepository)), user=Depends(get_current_admin_user)) -> ResourceTemplateInformationInList: templates_infos = await template_repo.get_templates_information(ResourceType.Workspace, user.roles if authorized_only else None) return ResourceTemplateInformationInList(templates=templates_infos) @workspace_templates_admin_router.get("/workspace-templates/{workspace_template_name}", response_model=WorkspaceTemplateInResponse, name=strings.API_GET_WORKSPACE_TEMPLATE_BY_NAME, response_model_exclude_none=True) -async def get_workspace_template(workspace_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspaceTemplateInResponse: +async def get_workspace_template(workspace_template_name: str, is_update: bool = False, version: Optional[str] = None, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspaceTemplateInResponse: template = await get_template(workspace_template_name, template_repo, ResourceType.Workspace, is_update=is_update, version=version) return parse_obj_as(WorkspaceTemplateInResponse, template) @workspace_templates_admin_router.post("/workspace-templates", status_code=status.HTTP_201_CREATED, response_model=WorkspaceTemplateInResponse, response_model_exclude_none=True, name=strings.API_CREATE_WORKSPACE_TEMPLATES) -async def register_workspace_template(template_input: WorkspaceTemplateInCreate, template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplateInResponse: +async def register_workspace_template(template_input: WorkspaceTemplateInCreate, template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplateInResponse: try: return await template_repo.create_and_validate_template(template_input, ResourceType.Workspace) except EntityVersionExist: diff --git a/api_app/api/routes/workspaces.py b/api_app/api/routes/workspaces.py index 9c2d38caf6..018a21999c 100644 --- a/api_app/api/routes/workspaces.py +++ b/api_app/api/routes/workspaces.py @@ -4,6 +4,7 @@ from jsonschema.exceptions import ValidationError +from api.helpers import get_repository from api.dependencies.workspaces import get_operation_by_id_from_path, get_workspace_by_id_from_path, get_deployed_workspace_by_id_from_path, get_deployed_workspace_service_by_id_from_path, get_workspace_service_by_id_from_path, get_user_resource_by_id_from_path from db.errors import InvalidInput, MajorVersionUpdateDenied, TargetTemplateVersionDoesNotExist, UserNotAuthorizedToUseTemplate, VersionDowngradeDenied from db.repositories.operations import OperationRepository @@ -55,7 +56,7 @@ def validate_user_has_valid_role_for_user_resource(user, user_resource): # WORKSPACE ROUTES @workspaces_core_router.get("/workspaces", response_model=WorkspacesInList, name=strings.API_GET_ALL_WORKSPACES) -async def retrieve_users_active_workspaces(request: Request, user=Depends(get_current_tre_user_or_tre_admin), workspace_repo=Depends(WorkspaceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspacesInList: +async def retrieve_users_active_workspaces(request: Request, user=Depends(get_current_tre_user_or_tre_admin), workspace_repo=Depends(get_repository(WorkspaceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspacesInList: try: user = await get_current_admin_user(request) @@ -82,7 +83,7 @@ def _safe_get_workspace_role(user, workspace, user_role_assignments): @workspaces_shared_router.get("/workspaces/{workspace_id}", response_model=WorkspaceInResponse, name=strings.API_GET_WORKSPACE_BY_ID) -async def retrieve_workspace_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspaceInResponse: +async def retrieve_workspace_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspaceInResponse: await enrich_resource_with_available_upgrades(workspace, resource_template_repo) return WorkspaceInResponse(workspace=workspace) @@ -96,7 +97,7 @@ async def retrieve_workspace_scope_id_by_workspace_id(workspace=Depends(get_work @workspaces_core_router.post("/workspaces", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_CREATE_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def create_workspace(workspace_create: WorkspaceInCreate, response: Response, user=Depends(get_current_admin_user), workspace_repo=Depends(WorkspaceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: +async def create_workspace(workspace_create: WorkspaceInCreate, response: Response, user=Depends(get_current_admin_user), workspace_repo=Depends(get_repository(WorkspaceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: try: # TODO: This requires Directory.ReadAll ( Application.Read.All ) to be enabled in the Azure AD application to enable a users workspaces to be listed. This should be made optional. auth_info = extract_auth_information(workspace_create.properties) @@ -124,7 +125,7 @@ async def create_workspace(workspace_create: WorkspaceInCreate, response: Respon @workspaces_core_router.patch("/workspaces/{workspace_id}", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_UPDATE_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def patch_workspace(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), workspace_repo: WorkspaceRepository = Depends(WorkspaceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: +async def patch_workspace(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), workspace_repo: WorkspaceRepository = Depends(get_repository(WorkspaceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: try: is_disablement = resource_patch.isEnabled is not None and not resource_patch.isEnabled if is_disablement: @@ -152,7 +153,7 @@ async def patch_workspace(resource_patch: ResourcePatch, response: Response, use @workspaces_core_router.delete("/workspaces/{workspace_id}", response_model=OperationInResponse, name=strings.API_DELETE_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def delete_workspace(response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(OperationRepository.get_repository()), workspace_repo=Depends(WorkspaceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: +async def delete_workspace(response: Response, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(get_repository(OperationRepository)), workspace_repo=Depends(get_repository(WorkspaceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: if await delete_validation(workspace, workspace_repo): operation = await send_uninstall_message( resource=workspace, @@ -170,7 +171,7 @@ async def delete_workspace(response: Response, user=Depends(get_current_admin_us @workspaces_core_router.post("/workspaces/{workspace_id}/invoke-action", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_INVOKE_ACTION_ON_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) -async def invoke_action_on_workspace(response: Response, action: str, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), workspace_repo=Depends(WorkspaceRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: +async def invoke_action_on_workspace(response: Response, action: str, user=Depends(get_current_admin_user), workspace=Depends(get_workspace_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), workspace_repo=Depends(get_repository(WorkspaceRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: operation = await send_custom_action_message( resource=workspace, resource_repo=workspace_repo, @@ -191,7 +192,7 @@ async def invoke_action_on_workspace(response: Response, action: str, user=Depen @workspaces_shared_router.get("/workspaces/{workspace_id}/workspace-service-templates", response_model=ResourceTemplateInformationInList, name=strings.API_GET_WORKSPACE_SERVICE_TEMPLATES_IN_WORKSPACE) async def get_workspace_service_templates( workspace=Depends(get_workspace_by_id_from_path), - template_repo=Depends(ResourceTemplateRepository.get_repository()), + template_repo=Depends(get_repository(ResourceTemplateRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager_or_tre_admin)) -> ResourceTemplateInformationInList: template_infos = await template_repo.get_templates_information(ResourceType.WorkspaceService, user.roles) return ResourceTemplateInformationInList(templates=template_infos) @@ -202,14 +203,14 @@ async def get_workspace_service_templates( async def get_user_resource_templates( service_template_name: str, workspace=Depends(get_workspace_by_id_from_path), - template_repo=Depends(ResourceTemplateRepository.get_repository()), + template_repo=Depends(get_repository(ResourceTemplateRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager_or_tre_admin)) -> ResourceTemplateInformationInList: template_infos = await template_repo.get_templates_information(ResourceType.UserResource, user.roles, service_template_name) return ResourceTemplateInformationInList(templates=template_infos) @workspaces_shared_router.get("/workspaces/{workspace_id}/operations", response_model=OperationInList, name=strings.API_GET_RESOURCE_OPERATIONS, dependencies=[Depends(get_current_workspace_owner_or_tre_admin)]) -async def retrieve_workspace_operations_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(OperationRepository.get_repository())) -> OperationInList: +async def retrieve_workspace_operations_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), operations_repo=Depends(get_repository(OperationRepository))) -> OperationInList: return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=workspace.id)) @@ -219,26 +220,26 @@ async def retrieve_workspace_operation_by_workspace_id_and_operation_id(workspac @workspaces_shared_router.get("/workspaces/{workspace_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_current_workspace_owner_or_tre_admin)]) -async def retrieve_workspace_history_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> ResourceHistoryInList: +async def retrieve_workspace_history_by_workspace_id(workspace=Depends(get_workspace_by_id_from_path), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=workspace.id)) # WORKSPACE SERVICES ROUTES @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services", response_model=WorkspaceServicesInList, name=strings.API_GET_ALL_WORKSPACE_SERVICES, dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)]) -async def retrieve_users_active_workspace_services(workspace=Depends(get_workspace_by_id_from_path), workspace_services_repo=Depends(WorkspaceServiceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspaceServicesInList: +async def retrieve_users_active_workspace_services(workspace=Depends(get_workspace_by_id_from_path), workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspaceServicesInList: workspace_services = await workspace_services_repo.get_active_workspace_services_for_workspace(workspace.id) await asyncio.gather(*[enrich_resource_with_available_upgrades(workspace_service, resource_template_repo) for workspace_service in workspace_services]) return WorkspaceServicesInList(workspaceServices=workspace_services) @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}", response_model=WorkspaceServiceInResponse, name=strings.API_GET_WORKSPACE_SERVICE_BY_ID, dependencies=[Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) -async def retrieve_workspace_service_by_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository())) -> WorkspaceServiceInResponse: +async def retrieve_workspace_service_by_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository))) -> WorkspaceServiceInResponse: await enrich_resource_with_available_upgrades(workspace_service, resource_template_repo) return WorkspaceServiceInResponse(workspaceService=workspace_service) @workspace_services_workspace_router.post("/workspaces/{workspace_id}/workspace-services", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_CREATE_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_user)]) -async def create_workspace_service(response: Response, workspace_service_input: WorkspaceServiceInCreate, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), workspace_repo=Depends(WorkspaceRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), workspace=Depends(get_deployed_workspace_by_id_from_path)) -> OperationInResponse: +async def create_workspace_service(response: Response, workspace_service_input: WorkspaceServiceInCreate, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), workspace_repo=Depends(get_repository(WorkspaceRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), workspace=Depends(get_deployed_workspace_by_id_from_path)) -> OperationInResponse: try: workspace_service, resource_template = await workspace_service_repo.create_workspace_service_item(workspace_service_input, workspace.id, user.roles) @@ -279,7 +280,7 @@ async def create_workspace_service(response: Response, workspace_service_input: @workspace_services_workspace_router.patch("/workspaces/{workspace_id}/workspace-services/{service_id}", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_UPDATE_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_or_researcher_user), Depends(get_workspace_by_id_from_path)]) -async def patch_workspace_service(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: +async def patch_workspace_service(resource_patch: ResourcePatch, response: Response, user=Depends(get_current_workspace_owner_user), workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: try: is_disablement = resource_patch.isEnabled is not None and not resource_patch.isEnabled if is_disablement: @@ -305,7 +306,7 @@ async def patch_workspace_service(resource_patch: ResourcePatch, response: Respo @workspace_services_workspace_router.delete("/workspaces/{workspace_id}/workspace-services/{service_id}", response_model=OperationInResponse, name=strings.API_DELETE_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_user)]) -async def delete_workspace_service(response: Response, user=Depends(get_current_workspace_owner_user), workspace=Depends(get_workspace_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), user_resource_repo=Depends(UserResourceRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: +async def delete_workspace_service(response: Response, user=Depends(get_current_workspace_owner_user), workspace=Depends(get_workspace_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), user_resource_repo=Depends(get_repository(UserResourceRepository)), operations_repo=Depends(get_repository(OperationRepository)), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: if await delete_validation(workspace_service, workspace_service_repo): operation = await send_uninstall_message( resource=workspace_service, @@ -323,7 +324,7 @@ async def delete_workspace_service(response: Response, user=Depends(get_current_ @workspace_services_workspace_router.post("/workspaces/{workspace_id}/workspace-services/{service_id}/invoke-action", status_code=status.HTTP_202_ACCEPTED, response_model=OperationInResponse, name=strings.API_INVOKE_ACTION_ON_WORKSPACE_SERVICE, dependencies=[Depends(get_current_workspace_owner_user)]) -async def invoke_action_on_workspace_service(response: Response, action: str, user=Depends(get_current_workspace_owner_user), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), operations_repo=Depends(OperationRepository.get_repository()), workspace_service_repo=Depends(WorkspaceServiceRepository.get_repository()), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: +async def invoke_action_on_workspace_service(response: Response, action: str, user=Depends(get_current_workspace_owner_user), workspace_service=Depends(get_workspace_service_by_id_from_path), resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), operations_repo=Depends(get_repository(OperationRepository)), workspace_service_repo=Depends(get_repository(WorkspaceServiceRepository)), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: operation = await send_custom_action_message( resource=workspace_service, resource_repo=workspace_service_repo, @@ -341,7 +342,7 @@ async def invoke_action_on_workspace_service(response: Response, action: str, us # workspace service operations @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/operations", response_model=OperationInList, name=strings.API_GET_RESOURCE_OPERATIONS, dependencies=[Depends(get_current_workspace_owner_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) -async def retrieve_workspace_service_operations_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), operations_repo=Depends(OperationRepository.get_repository())) -> OperationInList: +async def retrieve_workspace_service_operations_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), operations_repo=Depends(get_repository(OperationRepository))) -> OperationInList: return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=workspace_service.id)) @@ -351,7 +352,7 @@ async def retrieve_workspace_service_operation_by_workspace_service_id_and_opera @workspace_services_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_current_workspace_owner_or_airlock_manager), Depends(get_workspace_by_id_from_path)]) -async def retrieve_workspace_service_history_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> ResourceHistoryInList: +async def retrieve_workspace_service_history_by_workspace_service_id(workspace_service=Depends(get_workspace_service_by_id_from_path), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=workspace_service.id)) @@ -361,8 +362,8 @@ async def retrieve_user_resources_for_workspace_service( workspace_id: str, service_id: str, user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), - resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), - user_resource_repo=Depends(UserResourceRepository.get_repository())) -> UserResourcesInList: + resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), + user_resource_repo=Depends(get_repository(UserResourceRepository))) -> UserResourcesInList: user_resources = await user_resource_repo.get_user_resources_for_workspace_service(workspace_id, service_id) # filter only to the user - for researchers @@ -381,7 +382,7 @@ async def retrieve_user_resources_for_workspace_service( @user_resources_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/user-resources/{resource_id}", response_model=UserResourceInResponse, name=strings.API_GET_USER_RESOURCE, dependencies=[Depends(get_workspace_by_id_from_path)]) async def retrieve_user_resource_by_id( user_resource=Depends(get_user_resource_by_id_from_path), - resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), + resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)) -> UserResourceInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) @@ -396,10 +397,10 @@ async def retrieve_user_resource_by_id( async def create_user_resource( response: Response, user_resource_create: UserResourceInCreate, - user_resource_repo=Depends(UserResourceRepository.get_repository()), - resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), - operations_repo=Depends(OperationRepository.get_repository()), - resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), + user_resource_repo=Depends(get_repository(UserResourceRepository)), + resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), + operations_repo=Depends(get_repository(OperationRepository)), + resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), workspace=Depends(get_deployed_workspace_by_id_from_path), workspace_service=Depends(get_deployed_workspace_service_by_id_from_path)) -> OperationInResponse: @@ -432,10 +433,10 @@ async def delete_user_resource( user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), user_resource=Depends(get_user_resource_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), - user_resource_repo=Depends(UserResourceRepository.get_repository()), - operations_repo=Depends(OperationRepository.get_repository()), - resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), - resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> OperationInResponse: + user_resource_repo=Depends(get_repository(UserResourceRepository)), + operations_repo=Depends(get_repository(OperationRepository)), + resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), + resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> OperationInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) if user_resource.isEnabled: @@ -462,10 +463,10 @@ async def patch_user_resource( user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), user_resource=Depends(get_user_resource_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), - user_resource_repo=Depends(UserResourceRepository.get_repository()), - resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), - resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), - operations_repo=Depends(OperationRepository.get_repository()), + user_resource_repo=Depends(get_repository(UserResourceRepository)), + resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), + resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), + operations_repo=Depends(get_repository(OperationRepository)), etag: str = Header(...), force_version_update: bool = False) -> OperationInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) @@ -489,10 +490,10 @@ async def invoke_action_on_user_resource( action: str, user_resource=Depends(get_user_resource_by_id_from_path), workspace_service=Depends(get_workspace_service_by_id_from_path), - resource_template_repo=Depends(ResourceTemplateRepository.get_repository()), - user_resource_repo=Depends(UserResourceRepository.get_repository()), - operations_repo=Depends(OperationRepository.get_repository()), - resource_history_repo=Depends(ResourceHistoryRepository.get_repository()), + resource_template_repo=Depends(get_repository(ResourceTemplateRepository)), + user_resource_repo=Depends(get_repository(UserResourceRepository)), + operations_repo=Depends(get_repository(OperationRepository)), + resource_history_repo=Depends(get_repository(ResourceHistoryRepository)), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager)) -> OperationInResponse: validate_user_has_valid_role_for_user_resource(user, user_resource) operation = await send_custom_action_message( @@ -516,7 +517,7 @@ async def invoke_action_on_user_resource( async def retrieve_user_resource_operations_by_user_resource_id( user_resource=Depends(get_user_resource_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), - operations_repo=Depends(OperationRepository.get_repository())) -> OperationInList: + operations_repo=Depends(get_repository(OperationRepository))) -> OperationInList: validate_user_has_valid_role_for_user_resource(user, user_resource) return OperationInList(operations=await operations_repo.get_operations_by_resource_id(resource_id=user_resource.id)) @@ -531,6 +532,6 @@ async def retrieve_user_resource_operations_by_user_resource_id_and_operation_id @user_resources_workspace_router.get("/workspaces/{workspace_id}/workspace-services/{service_id}/user-resources/{resource_id}/history", response_model=ResourceHistoryInList, name=strings.API_GET_RESOURCE_HISTORY, dependencies=[Depends(get_workspace_by_id_from_path)]) -async def retrieve_user_resource_history_by_user_resource_id(user_resource=Depends(get_user_resource_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), resource_history_repo=Depends(ResourceHistoryRepository.get_repository())) -> ResourceHistoryInList: +async def retrieve_user_resource_history_by_user_resource_id(user_resource=Depends(get_user_resource_by_id_from_path), user=Depends(get_current_workspace_owner_or_researcher_user_or_airlock_manager), resource_history_repo=Depends(get_repository(ResourceHistoryRepository))) -> ResourceHistoryInList: validate_user_has_valid_role_for_user_resource(user, user_resource) return ResourceHistoryInList(resource_history=await resource_history_repo.get_resource_history_by_resource_id(resource_id=user_resource.id)) From bfca0cd1d32bd29c5d02f0c022ee91fc30e54da0 Mon Sep 17 00:00:00 2001 From: marrobi Date: Tue, 2 Jan 2024 15:43:23 +0000 Subject: [PATCH 23/32] reduce churn --- api_app/api/dependencies/airlock.py | 3 ++- api_app/api/dependencies/shared_services.py | 5 +++-- .../workspace_service_templates.py | 3 ++- api_app/api/dependencies/workspaces.py | 13 ++++++----- api_app/db/repositories/base.py | 19 +--------------- api_app/services/aad_authentication.py | 1 - .../test_api/test_routes/test_workspaces.py | 4 ++-- .../test_deployment_status_update.py | 8 +++---- .../test_resource_request_sender.py | 22 +++++++++---------- 9 files changed, 32 insertions(+), 46 deletions(-) diff --git a/api_app/api/dependencies/airlock.py b/api_app/api/dependencies/airlock.py index 1a8ee75994..c824352ee5 100644 --- a/api_app/api/dependencies/airlock.py +++ b/api_app/api/dependencies/airlock.py @@ -1,6 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 +from api.helpers import get_repository from db.repositories.airlock_requests import AirlockRequestRepository from models.domain.airlock_request import AirlockRequest from db.errors import EntityDoesNotExist, UnableToAccessDatabase @@ -16,5 +17,5 @@ async def get_airlock_request_by_id(airlock_request_id: UUID4, airlock_request_r raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=strings.STATE_STORE_ENDPOINT_NOT_RESPONDING) -async def get_airlock_request_by_id_from_path(airlock_request_id: UUID4 = Path(...), airlock_request_repo=Depends(AirlockRequestRepository.get_repository())) -> AirlockRequest: +async def get_airlock_request_by_id_from_path(airlock_request_id: UUID4 = Path(...), airlock_request_repo=Depends(get_repository(AirlockRequestRepository))) -> AirlockRequest: return await get_airlock_request_by_id(airlock_request_id, airlock_request_repo) diff --git a/api_app/api/dependencies/shared_services.py b/api_app/api/dependencies/shared_services.py index f84a5e4b6a..0dc0320810 100644 --- a/api_app/api/dependencies/shared_services.py +++ b/api_app/api/dependencies/shared_services.py @@ -1,6 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 +from api_app.api.helpers import get_repository from db.errors import EntityDoesNotExist from resources import strings from models.domain.shared_service import SharedService @@ -16,11 +17,11 @@ async def get_shared_service_by_id(shared_service_id: UUID4, shared_services_rep raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.SHARED_SERVICE_DOES_NOT_EXIST) -async def get_shared_service_by_id_from_path(shared_service_id: UUID4 = Path(...), shared_service_repo=Depends(SharedServiceRepository.get_repository())) -> SharedService: +async def get_shared_service_by_id_from_path(shared_service_id: UUID4 = Path(...), shared_service_repo=Depends(get_repository(SharedServiceRepository))) -> SharedService: return await get_shared_service_by_id(shared_service_id, shared_service_repo) -async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(OperationRepository.get_repository())) -> Operation: +async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(get_repository(OperationRepository))) -> Operation: try: return await operations_repo.get_operation_by_id(operation_id=operation_id) except EntityDoesNotExist: diff --git a/api_app/api/dependencies/workspace_service_templates.py b/api_app/api/dependencies/workspace_service_templates.py index b06864edcd..0f8231253f 100644 --- a/api_app/api/dependencies/workspace_service_templates.py +++ b/api_app/api/dependencies/workspace_service_templates.py @@ -1,5 +1,6 @@ from fastapi import Depends, HTTPException, Path, status +from api_app.api.helpers import get_repository from db.errors import EntityDoesNotExist from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType @@ -7,7 +8,7 @@ from resources import strings -async def get_workspace_service_template_by_name_from_path(service_template_name: str = Path(...), template_repo=Depends(ResourceTemplateRepository.get_repository())) -> ResourceTemplate: +async def get_workspace_service_template_by_name_from_path(service_template_name: str = Path(...), template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplate: try: return await template_repo.get_current_template(service_template_name, ResourceType.WorkspaceService) except EntityDoesNotExist: diff --git a/api_app/api/dependencies/workspaces.py b/api_app/api/dependencies/workspaces.py index 56263b3938..f6a5b1ee19 100644 --- a/api_app/api/dependencies/workspaces.py +++ b/api_app/api/dependencies/workspaces.py @@ -1,6 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 +from api_app.api.helpers import get_repository from db.errors import EntityDoesNotExist, ResourceIsNotDeployed from db.repositories.operations import OperationRepository from db.repositories.user_resources import UserResourceRepository @@ -21,11 +22,11 @@ async def get_workspace_by_id(workspace_id: UUID4, workspaces_repo) -> Workspace raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.WORKSPACE_DOES_NOT_EXIST) -async def get_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(WorkspaceRepository)) -> Workspace: +async def get_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(get_repository(WorkspaceRepository))) -> Workspace: return await get_workspace_by_id(workspace_id, workspaces_repo) -async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(WorkspaceRepository), operations_repo=Depends(OperationRepository.get_repository())) -> Workspace: +async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(get_repository(WorkspaceRepository)), operations_repo=Depends(get_repository(OperationRepository))) -> Workspace: try: return await workspaces_repo.get_deployed_workspace_by_id(workspace_id, operations_repo) except EntityDoesNotExist: @@ -34,14 +35,14 @@ async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...) raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=strings.WORKSPACE_IS_NOT_DEPLOYED) -async def get_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(WorkspaceServiceRepository)) -> WorkspaceService: +async def get_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository))) -> WorkspaceService: try: return await workspace_services_repo.get_workspace_service_by_id(workspace_id, service_id) except EntityDoesNotExist: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.WORKSPACE_SERVICE_DOES_NOT_EXIST) -async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(WorkspaceServiceRepository), operations_repo=Depends(OperationRepository.get_repository())) -> WorkspaceService: +async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository)), operations_repo=Depends(get_repository(OperationRepository))) -> WorkspaceService: try: return await workspace_services_repo.get_deployed_workspace_service_by_id(workspace_id, service_id, operations_repo) except EntityDoesNotExist: @@ -50,14 +51,14 @@ async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = P raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=strings.WORKSPACE_SERVICE_IS_NOT_DEPLOYED) -async def get_user_resource_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), resource_id: UUID4 = Path(...), user_resource_repo=Depends(UserResourceRepository)) -> UserResource: +async def get_user_resource_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), resource_id: UUID4 = Path(...), user_resource_repo=Depends(get_repository(UserResourceRepository))) -> UserResource: try: return await user_resource_repo.get_user_resource_by_id(workspace_id, service_id, resource_id) except EntityDoesNotExist: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.USER_RESOURCE_DOES_NOT_EXIST) -async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(OperationRepository.get_repository())) -> Operation: +async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(get_repository(OperationRepository))) -> Operation: try: return await operations_repo.get_operation_by_id(operation_id=operation_id) except EntityDoesNotExist: diff --git a/api_app/db/repositories/base.py b/api_app/db/repositories/base.py index 9ac8af6eb0..ae3d2a6008 100644 --- a/api_app/db/repositories/base.py +++ b/api_app/db/repositories/base.py @@ -1,13 +1,10 @@ -from typing import Callable, Optional +from typing import Optional from azure.cosmos.aio import ContainerProxy from azure.core import MatchConditions -from fastapi import HTTPException, status from pydantic import BaseModel from api.dependencies.database import Database -from resources.strings import STATE_STORE_ENDPOINT_NOT_RESPONDING from db.errors import UnableToAccessDatabase -from services.logging import logger class BaseRepository: @@ -16,20 +13,6 @@ async def create(cls, container_name: Optional[str] = None): cls._container: ContainerProxy = await cls._get_container(container_name) return cls - @classmethod - def get_repository(cls) -> Callable: - async def _get_repo() -> BaseRepository: - try: - return await cls.create() - except UnableToAccessDatabase: - logger.exception(STATE_STORE_ENDPOINT_NOT_RESPONDING) - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=STATE_STORE_ENDPOINT_NOT_RESPONDING, - ) - - return _get_repo - @property def container(self) -> ContainerProxy: return self._container diff --git a/api_app/services/aad_authentication.py b/api_app/services/aad_authentication.py index 70d8e19b45..81dd486a8f 100644 --- a/api_app/services/aad_authentication.py +++ b/api_app/services/aad_authentication.py @@ -119,7 +119,6 @@ async def _fetch_ws_app_reg_id_from_ws_id(request: Request) -> str: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=strings.AUTH_COULD_NOT_VALIDATE_CREDENTIALS) try: workspace_id = request.path_params['workspace_id'] - ws_repo = await WorkspaceRepository.create() workspace = await ws_repo.get_workspace_by_id(workspace_id) diff --git a/api_app/tests_ma/test_api/test_routes/test_workspaces.py b/api_app/tests_ma/test_api/test_routes/test_workspaces.py index f948a13195..0febe3694b 100644 --- a/api_app/tests_ma/test_api/test_routes/test_workspaces.py +++ b/api_app/tests_ma/test_api/test_routes/test_workspaces.py @@ -676,7 +676,7 @@ async def test_delete_workspace_returns_400_if_associated_workspace_services_are # [DELETE] /workspaces/{workspace_id} @ patch("api.routes.resource_helpers.ResourceRepository.get_resource_dependency_list") @ patch("api.routes.workspaces.ResourceTemplateRepository.get_template_by_name_and_version") - @ patch("db.repositories.base.BaseRepository.get_repository") + @ patch("api.helpers.get_repository") @ patch("api.dependencies.workspaces.WorkspaceRepository.get_workspace_by_id") @ patch("api.routes.workspaces.WorkspaceServiceRepository.get_active_workspace_services_for_workspace", return_value=[]) @ patch('api.routes.resource_helpers.send_resource_request_message', return_value=sample_resource_operation(resource_id=WORKSPACE_ID, operation_id=OPERATION_ID)) @@ -691,7 +691,7 @@ async def test_delete_workspace_sends_a_request_message_to_uninstall_the_workspa # [DELETE] /workspaces/{workspace_id} @ patch("api.routes.resource_helpers.ResourceRepository.get_resource_dependency_list") @ patch("api.routes.workspaces.ResourceTemplateRepository.get_template_by_name_and_version") - @ patch("db.repositories.base.BaseRepository.get_repository") + @ patch("api.helpers.get_repository") @ patch("api.dependencies.workspaces.WorkspaceRepository.get_workspace_by_id") @ patch("api.routes.workspaces.WorkspaceServiceRepository.get_active_workspace_services_for_workspace") async def test_delete_workspace_raises_503_if_marking_the_resource_as_deleted_in_the_db_fails(self, ___, get_workspace_mock, _____, resource_template_repo, ______, client, app, disabled_workspace, basic_resource_template): diff --git a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py index aa69cddd12..b013914ba2 100644 --- a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py @@ -134,9 +134,9 @@ async def test_receiving_bad_json_logs_error(cosmos_client, logging_mock, payloa @patch('service_bus.deployment_status_updater.ResourceHistoryRepository.create') -@patch('service_bus.deployment_status_updater.ResourceTemplateRepository') -@patch('service_bus.deployment_status_updater.OperationRepository') -@patch('service_bus.deployment_status_updater.ResourceRepository') +@patch('service_bus.deployment_status_updater.ResourceTemplateRepository.create') +@patch('service_bus.deployment_status_updater.OperationRepository.create') +@patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('services.logging.logger.exception') @patch("api.dependencies.database.Database.get_db_client") async def test_receiving_good_message(cosmos_client, logging_mock, resource_repo, operation_repo, _, __): @@ -296,7 +296,7 @@ async def test_properties_dont_change_with_no_outputs(cosmos_client, resource_re @patch('service_bus.deployment_status_updater.ResourceHistoryRepository.create') -@patch('service_bus.deployment_status_updater.ResourceTemplateRepository'.create) +@patch('service_bus.deployment_status_updater.ResourceTemplateRepository.create') @patch('service_bus.deployment_status_updater.update_resource_for_step') @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') diff --git a/api_app/tests_ma/test_service_bus/test_resource_request_sender.py b/api_app/tests_ma/test_service_bus/test_resource_request_sender.py index e90894791d..71e3785191 100644 --- a/api_app/tests_ma/test_service_bus/test_resource_request_sender.py +++ b/api_app/tests_ma/test_service_bus/test_resource_request_sender.py @@ -40,11 +40,11 @@ def create_test_resource(): @pytest.mark.parametrize( "request_action", [RequestAction.Install, RequestAction.UnInstall] ) -@patch("service_bus.resource_request_sender.ResourceHistoryRepository") -@patch("service_bus.resource_request_sender.OperationRepository") +@patch("service_bus.resource_request_sender.ResourceHistoryRepository.create") +@patch("service_bus.resource_request_sender.OperationRepository.create") @patch("service_bus.helpers.ServiceBusClient") -@patch("service_bus.resource_request_sender.ResourceRepository") -@patch("service_bus.resource_request_sender.ResourceTemplateRepository") +@patch("service_bus.resource_request_sender.ResourceRepository.create") +@patch("service_bus.resource_request_sender.ResourceTemplateRepository.create") async def test_resource_request_message_generated_correctly( resource_template_repo, resource_repo, @@ -84,10 +84,10 @@ async def test_resource_request_message_generated_correctly( assert sent_message_as_json["action"] == request_action -@patch("service_bus.resource_request_sender.ResourceHistoryRepository") -@patch("service_bus.resource_request_sender.OperationRepository") -@patch("service_bus.resource_request_sender.ResourceRepository") -@patch("service_bus.resource_request_sender.ResourceTemplateRepository") +@patch("service_bus.resource_request_sender.ResourceHistoryRepository.create") +@patch("service_bus.resource_request_sender.OperationRepository.create") +@patch("service_bus.resource_request_sender.ResourceRepository.create") +@patch("service_bus.resource_request_sender.ResourceTemplateRepository.create") async def test_multi_step_document_sends_first_step( resource_template_repo, resource_repo, @@ -146,9 +146,9 @@ async def test_multi_step_document_sends_first_step( ) -@patch("service_bus.resource_request_sender.ResourceHistoryRepository") -@patch("service_bus.resource_request_sender.ResourceRepository") -@patch("service_bus.resource_request_sender.ResourceTemplateRepository") +@patch("service_bus.resource_request_sender.ResourceHistoryRepository.create") +@patch("service_bus.resource_request_sender.ResourceRepository.create") +@patch("service_bus.resource_request_sender.ResourceTemplateRepository.create") async def test_multi_step_document_retries( resource_template_repo, resource_repo, From 4ff113efd11b7b21cbd691a63964eada7310b4f3 Mon Sep 17 00:00:00 2001 From: marrobi Date: Tue, 2 Jan 2024 15:48:49 +0000 Subject: [PATCH 24/32] Fix bad imports --- api_app/_version.py | 2 +- api_app/api/dependencies/shared_services.py | 2 +- api_app/api/dependencies/workspace_service_templates.py | 2 +- api_app/api/dependencies/workspaces.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index 391a39001a..bcea63d014 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.18.5" +__version__ = "0.18.6" diff --git a/api_app/api/dependencies/shared_services.py b/api_app/api/dependencies/shared_services.py index 0dc0320810..87bf4474cc 100644 --- a/api_app/api/dependencies/shared_services.py +++ b/api_app/api/dependencies/shared_services.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api_app.api.helpers import get_repository +from api.helpers import get_repository from db.errors import EntityDoesNotExist from resources import strings from models.domain.shared_service import SharedService diff --git a/api_app/api/dependencies/workspace_service_templates.py b/api_app/api/dependencies/workspace_service_templates.py index 0f8231253f..56bf60e869 100644 --- a/api_app/api/dependencies/workspace_service_templates.py +++ b/api_app/api/dependencies/workspace_service_templates.py @@ -1,6 +1,6 @@ from fastapi import Depends, HTTPException, Path, status -from api_app.api.helpers import get_repository +from api.helpers import get_repository from db.errors import EntityDoesNotExist from db.repositories.resource_templates import ResourceTemplateRepository from models.domain.resource import ResourceType diff --git a/api_app/api/dependencies/workspaces.py b/api_app/api/dependencies/workspaces.py index f6a5b1ee19..aae2cc7213 100644 --- a/api_app/api/dependencies/workspaces.py +++ b/api_app/api/dependencies/workspaces.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, Path, status from pydantic import UUID4 -from api_app.api.helpers import get_repository +from api.helpers import get_repository from db.errors import EntityDoesNotExist, ResourceIsNotDeployed from db.repositories.operations import OperationRepository from db.repositories.user_resources import UserResourceRepository From 8e98ebff6497591d1ce871e96a4b7b67810d4cd8 Mon Sep 17 00:00:00 2001 From: marrobi Date: Tue, 2 Jan 2024 17:23:13 +0000 Subject: [PATCH 25/32] fix reverted change --- api_app/_version.py | 2 +- api_app/db/repositories/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index bcea63d014..2e93a98791 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.18.6" +__version__ = "0.18.7" diff --git a/api_app/db/repositories/base.py b/api_app/db/repositories/base.py index ae3d2a6008..c73ccc78c1 100644 --- a/api_app/db/repositories/base.py +++ b/api_app/db/repositories/base.py @@ -21,7 +21,7 @@ def container(self) -> ContainerProxy: async def _get_container(cls, container_name) -> ContainerProxy: try: database = await Database().get_db_client() - container = await database.create_container_if_not_exists(id=container_name) + container = await database.get_container_client(container=container_name) return container except Exception: raise UnableToAccessDatabase From df711f6e4fd093bdd6dfe5faff212b797339a5c1 Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 3 Jan 2024 12:58:49 +0000 Subject: [PATCH 26/32] Move get container into database class --- api_app/api/dependencies/database.py | 29 ++++++++++-------- api_app/api/helpers.py | 9 +++--- api_app/db/repositories/base.py | 15 ++++------ api_app/resources/strings.py | 1 + api_app/services/health_checker.py | 6 ++-- api_app/tests_ma/conftest.py | 22 +++++++------- api_app/tests_ma/test_api/conftest.py | 10 ------- .../test_api/dependencies/__init__.py | 0 .../test_api/dependencies/test_database.py | 12 ++++++++ api_app/tests_ma/test_api/test_helpers.py | 17 +++++++++++ .../test_routes/test_resource_helpers.py | 2 +- .../test_workspace_migration.py | 2 +- .../test_airlock_request_repository.py | 2 +- .../test_repositories/test_base_repository.py | 13 ++++---- .../test_operation_repository.py | 6 ++-- .../test_resource_history_repository.py | 2 +- .../test_resource_repository.py | 4 +-- .../test_resource_templates_repository.py | 2 +- .../test_shared_service_repository.py | 4 +-- ...est_shared_service_templates_repository.py | 2 +- .../test_user_resource_repository.py | 2 +- ...test_user_resource_templates_repository.py | 2 +- .../test_workpaces_repository.py | 4 +-- .../test_workpaces_service_repository.py | 4 +-- .../test_airlock_request_status_update.py | 18 ++++------- .../test_deployment_status_update.py | 30 +++++++------------ .../test_services/test_health_checker.py | 27 +++++++---------- 27 files changed, 123 insertions(+), 124 deletions(-) create mode 100644 api_app/tests_ma/test_api/dependencies/__init__.py create mode 100644 api_app/tests_ma/test_api/dependencies/test_database.py create mode 100644 api_app/tests_ma/test_api/test_helpers.py diff --git a/api_app/api/dependencies/database.py b/api_app/api/dependencies/database.py index a22d548f32..7bfc89ff22 100644 --- a/api_app/api/dependencies/database.py +++ b/api_app/api/dependencies/database.py @@ -1,7 +1,7 @@ -from azure.cosmos.aio import CosmosClient +from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient -from core.config import MANAGED_IDENTITY_CLIENT_ID, STATE_STORE_ENDPOINT, STATE_STORE_KEY, STATE_STORE_SSL_VERIFY, SUBSCRIPTION_ID, RESOURCE_MANAGER_ENDPOINT, CREDENTIAL_SCOPES, RESOURCE_GROUP_NAME, COSMOSDB_ACCOUNT_NAME +from core.config import MANAGED_IDENTITY_CLIENT_ID, STATE_STORE_ENDPOINT, STATE_STORE_KEY, STATE_STORE_SSL_VERIFY, SUBSCRIPTION_ID, RESOURCE_MANAGER_ENDPOINT, CREDENTIAL_SCOPES, RESOURCE_GROUP_NAME, COSMOSDB_ACCOUNT_NAME, STATE_STORE_DATABASE from core.credentials import get_credential_async from services.logging import logger @@ -16,13 +16,15 @@ def __call__(cls, *args, **kwargs): class Database(metaclass=Singleton): - cosmos_client = None - def __init__(self): + _cosmos_client: CosmosClient = None + _database_proxy: DatabaseProxy = None + + def __init__(cls): pass @classmethod - async def _connect_to_db(self) -> CosmosClient: + async def _connect_to_db(cls) -> CosmosClient: logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}") credential = await get_credential_async() @@ -34,7 +36,7 @@ async def _connect_to_db(self) -> CosmosClient: ) else: logger.debug("Connecting with key") - primary_master_key = await self._get_store_key(credential) + primary_master_key = await cls._get_store_key(credential) if STATE_STORE_SSL_VERIFY: logger.debug("Connecting with SSL verification") @@ -54,7 +56,7 @@ async def _connect_to_db(self) -> CosmosClient: return cosmos_client @classmethod - async def _get_store_key(self, credential) -> str: + async def _get_store_key(cls, credential) -> str: logger.debug("Getting store key") if STATE_STORE_KEY: primary_master_key = STATE_STORE_KEY @@ -74,8 +76,11 @@ async def _get_store_key(self, credential) -> str: return primary_master_key @classmethod - async def get_db_client(self) -> CosmosClient: - logger.debug("Getting cosmos client") - if not Database.cosmos_client: - Database.cosmos_client = await self._connect_to_db() - return self.cosmos_client + async def get_container_proxy(cls, container_name) -> ContainerProxy: + if cls._cosmos_client is None: + cls._cosmos_client = await cls._connect_to_db() + + if cls._database_proxy is None: + cls._database_proxy = cls._cosmos_client.get_database_client(STATE_STORE_DATABASE) + + return cls._database_proxy.get_container_client(container_name) diff --git a/api_app/api/helpers.py b/api_app/api/helpers.py index 8cd137d59a..1c2d3a0529 100644 --- a/api_app/api/helpers.py +++ b/api_app/api/helpers.py @@ -1,10 +1,11 @@ from typing import Callable, Type -from fastapi import HTTPException, logger, status +from fastapi import HTTPException, status from db.errors import UnableToAccessDatabase from db.repositories.base import BaseRepository -from resources.strings import STATE_STORE_ENDPOINT_NOT_RESPONDING +from resources.strings import UNABLE_TO_GET_STATE_STORE_CLIENT +from services.logging import logger def get_repository(repo_type: Type[BaseRepository],) -> Callable: @@ -12,10 +13,10 @@ async def _get_repo() -> BaseRepository: try: return await repo_type.create() except UnableToAccessDatabase: - logger.exception(STATE_STORE_ENDPOINT_NOT_RESPONDING) + logger.exception(UNABLE_TO_GET_STATE_STORE_CLIENT) raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=STATE_STORE_ENDPOINT_NOT_RESPONDING, + detail=UNABLE_TO_GET_STATE_STORE_CLIENT, ) return _get_repo diff --git a/api_app/db/repositories/base.py b/api_app/db/repositories/base.py index c73ccc78c1..7fe5371b5c 100644 --- a/api_app/db/repositories/base.py +++ b/api_app/db/repositories/base.py @@ -10,22 +10,17 @@ class BaseRepository: @classmethod async def create(cls, container_name: Optional[str] = None): - cls._container: ContainerProxy = await cls._get_container(container_name) + try: + cls._container: ContainerProxy = await Database().get_container_proxy(container_name) + except Exception: + raise UnableToAccessDatabase + return cls @property def container(self) -> ContainerProxy: return self._container - @classmethod - async def _get_container(cls, container_name) -> ContainerProxy: - try: - database = await Database().get_db_client() - container = await database.get_container_client(container=container_name) - return container - except Exception: - raise UnableToAccessDatabase - async def query(self, query: str, parameters: Optional[dict] = None): items = self.container.query_items(query=query, parameters=parameters) return [i async for i in items] diff --git a/api_app/resources/strings.py b/api_app/resources/strings.py index 0c78bd7850..9c2d7ff4b4 100644 --- a/api_app/resources/strings.py +++ b/api_app/resources/strings.py @@ -82,6 +82,7 @@ OK = "OK" NOT_OK = "Not OK" COSMOS_DB = "Cosmos DB" +UNABLE_TO_GET_STATE_STORE_CLIENT = "Unable to get state store client" STATE_STORE_ENDPOINT_NOT_RESPONDING = "State Store endpoint is not responding" STATE_STORE_ENDPOINT_NOT_ACCESSIBLE = "State Store endpoint is not accessible" UNSPECIFIED_ERROR = "Unspecified error" diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index 6bafcbd499..afff83c713 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -5,6 +5,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError from azure.servicebus.exceptions import ServiceBusConnectionError, ServiceBusAuthenticationError from api.dependencies.database import Database +from core.config import STATE_STORE_RESOURCES_CONTAINER from core import config from models.schemas.status import StatusEnum @@ -16,9 +17,8 @@ async def create_state_store_status() -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" try: - cosmos_client = await Database().get_db_client() - list_databases_response = cosmos_client.list_databases() - [database async for database in list_databases_response] + container = await Database().get_container_proxy(STATE_STORE_RESOURCES_CONTAINER) + await container.query_items("SELECT TOP 1 * FROM c") except exceptions.ServiceRequestError: status = StatusEnum.not_ok message = strings.STATE_STORE_ENDPOINT_NOT_RESPONDING diff --git a/api_app/tests_ma/conftest.py b/api_app/tests_ma/conftest.py index ccbb44ebfc..0bd06e076d 100644 --- a/api_app/tests_ma/conftest.py +++ b/api_app/tests_ma/conftest.py @@ -1,6 +1,9 @@ import pytest import pytest_asyncio -from mock import patch +from mock import AsyncMock, patch +from azure.cosmos.aio import CosmosClient, DatabaseProxy + +from api.dependencies.database import Database from models.domain.request_action import RequestAction from models.domain.resource import Resource from models.domain.user_resource import UserResource @@ -572,13 +575,10 @@ def simple_pipeline_step() -> PipelineStep: ) -@pytest_asyncio.fixture() -def no_database(): - """overrides connecting to the database""" - with patch("api.dependencies.database.Database._connect_to_db", return_value=None): - with patch("api.dependencies.database.Database.get_db_client", return_value=None): - with patch( - "db.repositories.base.BaseRepository._get_container", return_value=None - ): - with patch("db.events.bootstrap_database", return_value=None): - yield +@pytest_asyncio.fixture(autouse=True) +async def no_database(): + with patch('api.dependencies.database.get_credential_async', return_value=AsyncMock()), \ + patch('api.dependencies.database.CosmosDBManagementClient', return_value=AsyncMock()), \ + patch('api.dependencies.database.CosmosClient', return_value=AsyncMock(spec=CosmosClient)) as cosmos_client_mock: + cosmos_client_mock.return_value.get_database_client.return_value = AsyncMock(spec=DatabaseProxy) + yield Database() diff --git a/api_app/tests_ma/test_api/conftest.py b/api_app/tests_ma/test_api/conftest.py index a22a6080e0..e781cf854e 100644 --- a/api_app/tests_ma/test_api/conftest.py +++ b/api_app/tests_ma/test_api/conftest.py @@ -14,16 +14,6 @@ def no_lifespan_events(): yield -@pytest_asyncio.fixture(autouse=True) -def no_database(): - """ overrides connecting to the database for all tests""" - with patch('api.dependencies.database.Database._connect_to_db', return_value=None): - with patch('api.dependencies.database.Database.get_db_client', return_value=None): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): - with patch('db.events.bootstrap_database', return_value=None): - yield - - @pytest.fixture(autouse=True) def no_auth_token(): """ overrides validating and decoding tokens for all tests""" diff --git a/api_app/tests_ma/test_api/dependencies/__init__.py b/api_app/tests_ma/test_api/dependencies/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api_app/tests_ma/test_api/dependencies/test_database.py b/api_app/tests_ma/test_api/dependencies/test_database.py new file mode 100644 index 0000000000..a1faefb453 --- /dev/null +++ b/api_app/tests_ma/test_api/dependencies/test_database.py @@ -0,0 +1,12 @@ +from mock import MagicMock +import pytest + +from api.dependencies.database import Database + +pytestmark = pytest.mark.asyncio + + +async def test_get_container_proxy(): + container_name = "test_container" + container_proxy = await Database().get_container_proxy(container_name) + assert isinstance(container_proxy, MagicMock) diff --git a/api_app/tests_ma/test_api/test_helpers.py b/api_app/tests_ma/test_api/test_helpers.py new file mode 100644 index 0000000000..c315a26346 --- /dev/null +++ b/api_app/tests_ma/test_api/test_helpers.py @@ -0,0 +1,17 @@ +import pytest +from mock import patch +from fastapi import HTTPException + +from db.errors import UnableToAccessDatabase +from db.repositories.base import BaseRepository +from api.helpers import get_repository + +pytestmark = pytest.mark.asyncio + + +@patch("db.repositories.base.BaseRepository.create") +async def test_get_repository_raises_http_exception_when_unable_to_access_database(create_base_repo_mock): + create_base_repo_mock.side_effect = UnableToAccessDatabase() + with pytest.raises(HTTPException): + get_repo = get_repository(BaseRepository) + await get_repo() diff --git a/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py b/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py index b7f61a6345..31fb31e5f1 100644 --- a/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py +++ b/api_app/tests_ma/test_api/test_routes/test_resource_helpers.py @@ -29,7 +29,7 @@ @pytest_asyncio.fixture async def resource_repo() -> ResourceRepository: - with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=AsyncMock()): resource_repo_mock = await ResourceRepository().create() yield resource_repo_mock diff --git a/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py b/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py index 5c4bfaaa7d..3da577094b 100644 --- a/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py +++ b/api_app/tests_ma/test_db/test_migrations/test_workspace_migration.py @@ -11,7 +11,7 @@ @pytest_asyncio.fixture async def workspace_migrator(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=AsyncMock()): workspace_migrator = await WorkspaceMigration.create() yield workspace_migrator diff --git a/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py b/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py index 0b91bb59f4..4c773db327 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py @@ -46,7 +46,7 @@ @pytest_asyncio.fixture async def airlock_request_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=AsyncMock()): airlock_request_repo_mock = await AirlockRequestRepository.create() yield airlock_request_repo_mock diff --git a/api_app/tests_ma/test_db/test_repositories/test_base_repository.py b/api_app/tests_ma/test_db/test_repositories/test_base_repository.py index 09c7cbf0bf..3d89923e24 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_base_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_base_repository.py @@ -1,6 +1,5 @@ import pytest - -from mock import patch, MagicMock +from mock import patch from db.errors import UnableToAccessDatabase from db.repositories.base import BaseRepository @@ -8,8 +7,8 @@ pytestmark = pytest.mark.asyncio -async def test_instantiating_a_repo_raises_unable_to_access_database_if_database_cant_be_accessed(): - with patch("azure.cosmos.CosmosClient") as cosmos_client_mock: - cosmos_client_mock.create_container_if_not_exists = MagicMock(side_effect=Exception) - with pytest.raises(UnableToAccessDatabase): - await BaseRepository.create("container") +@patch("api.dependencies.database.Database.get_container_proxy") +async def test_instantiating_a_repo_raises_unable_to_access_database_if_database_cant_be_accessed(get_container_proxy_mock): + get_container_proxy_mock.side_effect = Exception() + with pytest.raises(UnableToAccessDatabase): + await BaseRepository.create() diff --git a/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py b/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py index 43cdf3c4f1..91e51fba0a 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_operation_repository.py @@ -18,21 +18,21 @@ @pytest_asyncio.fixture async def operations_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): operations_repo = await OperationRepository.create() yield operations_repo @pytest_asyncio.fixture async def resource_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): resource_repo = await ResourceRepository.create() yield resource_repo @pytest_asyncio.fixture async def resource_template_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): resource_template_repo = await ResourceTemplateRepository.create() yield resource_template_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py index 4439911752..f4f37093c7 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_history_repository.py @@ -16,7 +16,7 @@ @pytest_asyncio.fixture async def resource_history_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): resource_history_repo = await ResourceHistoryRepository().create() yield resource_history_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py index c929469017..a308b7888f 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_repository.py @@ -26,14 +26,14 @@ @pytest_asyncio.fixture async def resource_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): resource_repo = await ResourceRepository().create() yield resource_repo @pytest_asyncio.fixture async def resource_history_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): resource_history_repo = await ResourceHistoryRepository().create() yield resource_history_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py index 20fa5f30fe..d007326323 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_resource_templates_repository.py @@ -14,7 +14,7 @@ @pytest_asyncio.fixture async def resource_template_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): resource_template_repo = await ResourceTemplateRepository().create() yield resource_template_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py b/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py index 95014a53f7..131b7c9d78 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_shared_service_repository.py @@ -17,14 +17,14 @@ @pytest_asyncio.fixture async def shared_service_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=AsyncMock()): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=AsyncMock()): shared_service_repo = await SharedServiceRepository().create() yield shared_service_repo @pytest_asyncio.fixture async def operations_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): operations_repo = await OperationRepository().create() yield operations_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py index e9d095bf01..226777f6f4 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_shared_service_templates_repository.py @@ -11,7 +11,7 @@ @pytest_asyncio.fixture async def resource_template_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): resource_template_repo = await ResourceTemplateRepository().create() yield resource_template_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py b/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py index 034b8155b3..90067c3356 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_user_resource_repository.py @@ -26,7 +26,7 @@ def basic_user_resource_request(): @pytest_asyncio.fixture async def user_resource_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): user_resource_repo = await UserResourceRepository().create() yield user_resource_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py b/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py index 01ec7a8116..727c9a2d18 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_user_resource_templates_repository.py @@ -13,7 +13,7 @@ @pytest_asyncio.fixture async def resource_template_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=None): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=None): resource_template_repo = await ResourceTemplateRepository().create() yield resource_template_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py b/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py index 5ad523ccb5..01727d4e24 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_workpaces_repository.py @@ -20,14 +20,14 @@ def basic_workspace_request(): @pytest_asyncio.fixture async def workspace_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=MagicMock()): workspace_repo = await WorkspaceRepository().create() yield workspace_repo @pytest_asyncio.fixture async def operations_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=MagicMock()): operations_repo = await OperationRepository().create() yield operations_repo diff --git a/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py b/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py index 80909411f5..390c0e18d2 100644 --- a/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py +++ b/api_app/tests_ma/test_db/test_repositories/test_workpaces_service_repository.py @@ -18,14 +18,14 @@ @pytest_asyncio.fixture async def workspace_service_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=MagicMock()): workspace_repo = await WorkspaceServiceRepository().create() yield workspace_repo @pytest_asyncio.fixture async def operations_repo(): - with patch('db.repositories.base.BaseRepository._get_container', return_value=MagicMock()): + with patch('api.dependencies.database.Database.get_container_proxy', return_value=MagicMock()): operations_repo = await OperationRepository().create() yield operations_repo diff --git a/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py b/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py index 76aed496c7..d9165db2f5 100644 --- a/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py @@ -108,9 +108,8 @@ def __str__(self): @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('logging.exception') -@patch("api.dependencies.database.Database.get_db_client") @patch("services.aad_authentication.AzureADAuthorization.get_workspace_role_assignment_details", return_value={"researcher_emails": ["researcher@outlook.com"], "owner_emails": ["owner@outlook.com"]}) -async def test_receiving_good_message(_, cosmos_client, logging_mock, workspace_repo, airlock_request_repo, eg_client): +async def test_receiving_good_message(_, logging_mock, workspace_repo, airlock_request_repo, eg_client): eg_client().send = AsyncMock() expected_airlock_request = sample_airlock_request() @@ -140,8 +139,7 @@ async def test_receiving_good_message(_, cosmos_client, logging_mock, workspace_ @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('services.logging.logger.exception') -@patch("api.dependencies.database.Database.get_db_client") -async def test_receiving_bad_json_logs_error(cosmos_client, logging_mock, workspace_repo, airlock_request_repo, payload): +async def test_receiving_bad_json_logs_error(logging_mock, workspace_repo, airlock_request_repo, payload): service_bus_received_message_mock = ServiceBusReceivedMessageMock(payload) airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() @@ -156,8 +154,7 @@ async def test_receiving_bad_json_logs_error(cosmos_client, logging_mock, worksp @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.exception') @patch('service_bus.airlock_request_status_update.ServiceBusClient') -@patch("api.dependencies.database.Database.get_db_client") -async def test_updating_non_existent_airlock_request_error_is_logged(cosmos_client, sb_client, logging_mock, airlock_request_repo, _): +async def test_updating_non_existent_airlock_request_error_is_logged(sb_client, logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message) airlock_request_repo.return_value.get_airlock_request_by_id.side_effect = EntityDoesNotExist @@ -173,8 +170,7 @@ async def test_updating_non_existent_airlock_request_error_is_logged(cosmos_clie @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.exception') -@patch("api.dependencies.database.Database.get_db_client") -async def test_when_updating_and_state_store_exception_error_is_logged(cosmos_client, logging_mock, airlock_request_repo, _): +async def test_when_updating_and_state_store_exception_error_is_logged(logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message) airlock_request_repo.return_value.get_airlock_request_by_id.side_effect = Exception @@ -189,8 +185,7 @@ async def test_when_updating_and_state_store_exception_error_is_logged(cosmos_cl @patch('service_bus.airlock_request_status_update.WorkspaceRepository.create') @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.error') -@patch("api.dependencies.database.Database.get_db_client") -async def test_when_updating_and_current_status_differs_from_status_in_state_store_error_is_logged(cosmos_client, logging_mock, airlock_request_repo, _): +async def test_when_updating_and_current_status_differs_from_status_in_state_store_error_is_logged(logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message) expected_airlock_request = sample_airlock_request(AirlockRequestStatus.Draft) @@ -208,8 +203,7 @@ async def test_when_updating_and_current_status_differs_from_status_in_state_sto @patch('service_bus.airlock_request_status_update.AirlockRequestRepository.create') @patch('services.logging.logger.exception') @patch('service_bus.airlock_request_status_update.ServiceBusClient') -@patch("api.dependencies.database.Database.get_db_client") -async def test_when_updating_and_status_update_is_illegal_error_is_logged(cosmos_client, sb_client, logging_mock, airlock_request_repo, _): +async def test_when_updating_and_status_update_is_illegal_error_is_logged(sb_client, logging_mock, airlock_request_repo, _): service_bus_received_message_mock = ServiceBusReceivedMessageMock(test_sb_step_result_message_with_invalid_status) airlock_request_repo.return_value.get_airlock_request_by_id.side_effect = HTTPException(status_code=status.HTTP_400_BAD_REQUEST) diff --git a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py index b013914ba2..db80c5b1f7 100644 --- a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py @@ -118,8 +118,7 @@ def create_sample_operation(resource_id, request_action): @pytest.mark.parametrize("payload", test_data) @patch('services.logging.logger.exception') -@patch("api.dependencies.database.Database.get_db_client") -async def test_receiving_bad_json_logs_error(cosmos_client, logging_mock, payload): +async def test_receiving_bad_json_logs_error(logging_mock, payload): service_bus_received_message_mock = ServiceBusReceivedMessageMock(payload) status_updater = DeploymentStatusUpdater() @@ -138,8 +137,7 @@ async def test_receiving_bad_json_logs_error(cosmos_client, logging_mock, payloa @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('services.logging.logger.exception') -@patch("api.dependencies.database.Database.get_db_client") -async def test_receiving_good_message(cosmos_client, logging_mock, resource_repo, operation_repo, _, __): +async def test_receiving_good_message(logging_mock, resource_repo, operation_repo, _, __): expected_workspace = create_sample_workspace_object(test_sb_message["id"]) resource_repo.return_value.get_resource_dict_by_id.return_value = expected_workspace.dict() @@ -161,8 +159,7 @@ async def test_receiving_good_message(cosmos_client, logging_mock, resource_repo @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('services.logging.logger.exception') -@patch("api.dependencies.database.Database.get_db_client") -async def test_when_updating_non_existent_workspace_error_is_logged(cosmos_client, logging_mock, resource_repo, operation_repo, _, __): +async def test_when_updating_non_existent_workspace_error_is_logged(logging_mock, resource_repo, operation_repo, _, __): resource_repo.return_value.get_resource_dict_by_id.side_effect = EntityDoesNotExist operation = create_sample_operation(test_sb_message["id"], RequestAction.Install) @@ -182,8 +179,7 @@ async def test_when_updating_non_existent_workspace_error_is_logged(cosmos_clien @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('services.logging.logger.exception') -@patch("api.dependencies.database.Database.get_db_client") -async def test_when_updating_and_state_store_exception(cosmos_client, logging_mock, resource_repo, operation_repo, _, __): +async def test_when_updating_and_state_store_exception(logging_mock, resource_repo, operation_repo, _, __): resource_repo.return_value.get_resource_dict_by_id.side_effect = Exception operation = create_sample_operation(test_sb_message["id"], RequestAction.Install) @@ -202,8 +198,7 @@ async def test_when_updating_and_state_store_exception(cosmos_client, logging_mo @patch("service_bus.deployment_status_updater.get_timestamp", return_value=FAKE_UPDATE_TIMESTAMP) @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') -@patch("api.dependencies.database.Database.get_db_client") -async def test_state_transitions_from_deployed_to_deleted(cosmos_client, resource_repo, operations_repo_mock, _, __, ___): +async def test_state_transitions_from_deployed_to_deleted(resource_repo, operations_repo_mock, _, __, ___): updated_message = test_sb_message updated_message["status"] = Status.Deleted updated_message["message"] = "Has been deleted" @@ -234,8 +229,7 @@ async def test_state_transitions_from_deployed_to_deleted(cosmos_client, resourc @patch('service_bus.deployment_status_updater.ResourceTemplateRepository.create') @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') -@patch("api.dependencies.database.Database.get_db_client") -async def test_outputs_are_added_to_resource_item(cosmos_client, resource_repo, operations_repo, _, __): +async def test_outputs_are_added_to_resource_item(resource_repo, operations_repo, _, __): received_message = test_sb_message_with_outputs received_message["status"] = Status.Deployed service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -272,8 +266,7 @@ async def test_outputs_are_added_to_resource_item(cosmos_client, resource_repo, @patch('service_bus.deployment_status_updater.ResourceTemplateRepository.create') @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') -@patch("api.dependencies.database.Database.get_db_client") -async def test_properties_dont_change_with_no_outputs(cosmos_client, resource_repo, operations_repo, _, __): +async def test_properties_dont_change_with_no_outputs(resource_repo, operations_repo, _, __): received_message = test_sb_message received_message["status"] = Status.Deployed service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -301,8 +294,7 @@ async def test_properties_dont_change_with_no_outputs(cosmos_client, resource_re @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('service_bus.helpers.ServiceBusClient') -@patch("api.dependencies.database.Database.get_db_client") -async def test_multi_step_operation_sends_next_step(cosmos_client, sb_sender_client, resource_repo, operations_repo, update_resource_for_step, _, __, multi_step_operation, user_resource_multi, basic_shared_service): +async def test_multi_step_operation_sends_next_step(sb_sender_client, resource_repo, operations_repo, update_resource_for_step, _, __, multi_step_operation, user_resource_multi, basic_shared_service): received_message = test_sb_message_multi_step_1_complete received_message["status"] = Status.Updated service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -356,8 +348,7 @@ async def test_multi_step_operation_sends_next_step(cosmos_client, sb_sender_cli @patch('service_bus.deployment_status_updater.OperationRepository.create') @patch('service_bus.deployment_status_updater.ResourceRepository.create') @patch('service_bus.helpers.ServiceBusClient') -@patch("api.dependencies.database.Database.get_db_client") -async def test_multi_step_operation_ends_at_last_step(cosmos_client, sb_sender_client, resource_repo, operations_repo, _, __, multi_step_operation, user_resource_multi, basic_shared_service): +async def test_multi_step_operation_ends_at_last_step(sb_sender_client, resource_repo, operations_repo, _, __, multi_step_operation, user_resource_multi, basic_shared_service): received_message = test_sb_message_multi_step_3_complete received_message["status"] = Status.Updated service_bus_received_message_mock = ServiceBusReceivedMessageMock(received_message) @@ -401,8 +392,7 @@ async def test_multi_step_operation_ends_at_last_step(cosmos_client, sb_sender_c sb_sender_client().get_queue_sender().send_messages.assert_not_called() -@patch("api.dependencies.database.Database.get_db_client") -async def test_convert_outputs_to_dict(cosmos_client): +async def test_convert_outputs_to_dict(): # Test case 1: Empty list of outputs outputs_list = [] expected_result = {} diff --git a/api_app/tests_ma/test_services/test_health_checker.py b/api_app/tests_ma/test_services/test_health_checker.py index d918b5c87e..4d053d7e80 100644 --- a/api_app/tests_ma/test_services/test_health_checker.py +++ b/api_app/tests_ma/test_services/test_health_checker.py @@ -11,34 +11,29 @@ pytestmark = pytest.mark.asyncio -@patch("api.dependencies.database.Database._get_store_key") -@patch("api.dependencies.database.Database.cosmos_client") -async def test_get_state_store_status_responding(_, get_store_key_mock) -> None: - get_store_key_mock.return_value = None +@patch("api.dependencies.database.Database.get_container_proxy", return_value=AsyncMock()) +async def test_get_state_store_status_responding(_) -> None: + # get_store_key_mock.return_value = None status, message = await health_checker.create_state_store_status() assert status == StatusEnum.ok assert message == "" -@patch("api.dependencies.database.Database._get_store_key") -@patch("api.dependencies.database.Database.get_db_client") -async def test_get_state_store_status_not_responding(cosmos_client_mock, get_store_key_mock) -> None: - get_store_key_mock.return_value = None - cosmos_client_mock.return_value = None - cosmos_client_mock.side_effect = ServiceRequestError(message="some message") +@patch("api.dependencies.database.Database.get_container_proxy") +async def test_get_state_store_status_not_responding(container_proxy_mock) -> None: + container_proxy_mock.return_value = None + container_proxy_mock.side_effect = ServiceRequestError(message="some message") status, message = await health_checker.create_state_store_status() assert status == StatusEnum.not_ok assert message == strings.STATE_STORE_ENDPOINT_NOT_RESPONDING -@patch("api.dependencies.database.Database._get_store_key") -@patch("api.dependencies.database.Database.get_db_client") -async def test_get_state_store_status_other_exception(cosmos_client_mock, get_store_key_mock) -> None: - get_store_key_mock.return_value = None - cosmos_client_mock.return_value = None - cosmos_client_mock.side_effect = Exception() +@patch("api.dependencies.database.Database.get_container_proxy") +async def test_get_state_store_status_other_exception(container_proxy_mock) -> None: + container_proxy_mock.return_value = None + container_proxy_mock.side_effect = Exception() status, message = await health_checker.create_state_store_status() assert status == StatusEnum.not_ok From f82fd0690b588f6d11c452e5647eb6128194545c Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 3 Jan 2024 14:02:30 +0000 Subject: [PATCH 27/32] Remove await --- api_app/_version.py | 2 +- api_app/services/health_checker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index 2e93a98791..782e3ece69 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.18.7" +__version__ = "0.18.8" diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index afff83c713..ceb384baa3 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -18,7 +18,7 @@ async def create_state_store_status() -> Tuple[StatusEnum, str]: message = "" try: container = await Database().get_container_proxy(STATE_STORE_RESOURCES_CONTAINER) - await container.query_items("SELECT TOP 1 * FROM c") + container.query_items("SELECT TOP 1 * FROM c") except exceptions.ServiceRequestError: status = StatusEnum.not_ok message = strings.STATE_STORE_ENDPOINT_NOT_RESPONDING From 7dfc153c390548e32aa200b9ac3aab761a196dde Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 3 Jan 2024 15:35:27 +0000 Subject: [PATCH 28/32] tidy up credentials modifications --- api_app/api/routes/health.py | 2 +- api_app/core/credentials.py | 39 ++++---------- api_app/event_grid/helpers.py | 2 +- .../airlock_request_status_update.py | 2 +- .../service_bus/deployment_status_updater.py | 2 +- api_app/service_bus/helpers.py | 2 +- api_app/services/health_checker.py | 3 +- api_app/tests_ma/test_db/test_events.py | 8 +-- .../test_services/test_health_checker.py | 51 +++++++++---------- 9 files changed, 45 insertions(+), 66 deletions(-) diff --git a/api_app/api/routes/health.py b/api_app/api/routes/health.py index 9d0ef42c12..2cefe21266 100644 --- a/api_app/api/routes/health.py +++ b/api_app/api/routes/health.py @@ -14,7 +14,7 @@ async def health_check(request: Request) -> HealthCheck: # The health endpoint checks the status of key components of the system. # Note that Resource Processor checks incur Azure management calls, so # calling this endpoint frequently may result in API throttling. - async with credentials.get_credential_async_cm() as credential: + async with credentials.get_credential_async_context() as credential: cosmos, sb, rp = await asyncio.gather( create_state_store_status(), create_service_bus_status(credential), diff --git a/api_app/core/credentials.py b/api_app/core/credentials.py index 331a44423c..b05ad1bf17 100644 --- a/api_app/core/credentials.py +++ b/api_app/core/credentials.py @@ -1,5 +1,5 @@ from contextlib import asynccontextmanager -from core import config +from core.config import MANAGED_IDENTITY_CLIENT_ID, AAD_AUTHORITY_URL from azure.core.credentials import TokenCredential from urllib.parse import urlparse @@ -16,13 +16,12 @@ def get_credential() -> TokenCredential: - managed_identity = config.MANAGED_IDENTITY_CLIENT_ID - if managed_identity: + if MANAGED_IDENTITY_CLIENT_ID: return ChainedTokenCredential( - ManagedIdentityCredential(client_id=managed_identity) + ManagedIdentityCredential(client_id=MANAGED_IDENTITY_CLIENT_ID) ) else: - return DefaultAzureCredential(authority=urlparse(config.AAD_AUTHORITY_URL).netloc, + return DefaultAzureCredential(authority=urlparse(AAD_AUTHORITY_URL).netloc, exclude_shared_token_cache_credential=True, exclude_workload_identity_credential=True, exclude_developer_cli_credential=True, @@ -30,18 +29,13 @@ def get_credential() -> TokenCredential: exclude_powershell_credential=True ) - -async def get_credential_async() -> TokenCredential: - """ - Context manager which yields the default credentials. - """ - managed_identity = config.MANAGED_IDENTITY_CLIENT_ID - credential = ( +async def get_credential_async(managed_identity): + return ( ChainedTokenCredentialASync( ManagedIdentityCredentialASync(client_id=managed_identity) ) if managed_identity - else DefaultAzureCredentialASync(authority=urlparse(config.AAD_AUTHORITY_URL).netloc, + else DefaultAzureCredentialASync(authority=urlparse(AAD_AUTHORITY_URL).netloc, exclude_shared_token_cache_credential=True, exclude_workload_identity_credential=True, exclude_developer_cli_credential=True, @@ -49,27 +43,12 @@ async def get_credential_async() -> TokenCredential: exclude_powershell_credential=True ) ) - return credential - @asynccontextmanager -async def get_credential_async_cm() -> TokenCredential: +async def get_credential_async_context() -> TokenCredential: """ Context manager which yields the default credentials. """ - managed_identity = config.MANAGED_IDENTITY_CLIENT_ID - credential = ( - ChainedTokenCredentialASync( - ManagedIdentityCredentialASync(client_id=managed_identity) - ) - if managed_identity - else DefaultAzureCredentialASync(authority=urlparse(config.AAD_AUTHORITY_URL).netloc, - exclude_shared_token_cache_credential=True, - exclude_workload_identity_credential=True, - exclude_developer_cli_credential=True, - exclude_managed_identity_credential=True, - exclude_powershell_credential=True - ) - ) + credential = await get_credential_async(MANAGED_IDENTITY_CLIENT_ID) yield credential await credential.close() diff --git a/api_app/event_grid/helpers.py b/api_app/event_grid/helpers.py index debfcfdfc3..bcad3e65e1 100644 --- a/api_app/event_grid/helpers.py +++ b/api_app/event_grid/helpers.py @@ -4,7 +4,7 @@ async def publish_event(event: EventGridEvent, topic_endpoint: str): - async with credentials.get_credential_async_cm() as credential: + async with credentials.get_credential_async_context() as credential: client = EventGridPublisherClient(topic_endpoint, credential) async with client: await client.send([event]) diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index a01156bae9..637e29c036 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -30,7 +30,7 @@ async def receive_messages(self): with tracer.start_as_current_span("airlock_receive_messages"): while True: try: - async with credentials.get_credential_async_cm() as credential: + async with credentials.get_credential_async_context() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) receiver = service_bus_client.get_queue_receiver(queue_name=config.SERVICE_BUS_STEP_RESULT_QUEUE) logger.info(f"Looking for new messages on {config.SERVICE_BUS_STEP_RESULT_QUEUE} queue...") diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 08b39d6b4e..4bac477754 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -39,7 +39,7 @@ async def receive_messages(self): with tracer.start_as_current_span("deployment_status_receive_messages"): while True: try: - async with credentials.get_credential_async_cm() as credential: + async with credentials.get_credential_async_context() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) logger.info(f"Looking for new messages on {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue...") diff --git a/api_app/service_bus/helpers.py b/api_app/service_bus/helpers.py index 79f11b687b..77f7127999 100644 --- a/api_app/service_bus/helpers.py +++ b/api_app/service_bus/helpers.py @@ -24,7 +24,7 @@ async def _send_message(message: ServiceBusMessage, queue: str): :param queue: The Service Bus queue to send the message to. :type queue: str """ - async with credentials.get_credential_async_cm() as credential: + async with credentials.get_credential_async_context() as credential: service_bus_client = ServiceBusClient(config.SERVICE_BUS_FULLY_QUALIFIED_NAMESPACE, credential) async with service_bus_client: diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index ceb384baa3..a4d53067b0 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -3,6 +3,7 @@ from azure.servicebus.aio import ServiceBusClient from azure.mgmt.compute.aio import ComputeManagementClient from azure.cosmos.exceptions import CosmosHttpResponseError +from azure.cosmos.aio import ContainerProxy from azure.servicebus.exceptions import ServiceBusConnectionError, ServiceBusAuthenticationError from api.dependencies.database import Database from core.config import STATE_STORE_RESOURCES_CONTAINER @@ -17,7 +18,7 @@ async def create_state_store_status() -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" try: - container = await Database().get_container_proxy(STATE_STORE_RESOURCES_CONTAINER) + container: ContainerProxy = await Database().get_container_proxy(STATE_STORE_RESOURCES_CONTAINER) container.query_items("SELECT TOP 1 * FROM c") except exceptions.ServiceRequestError: status = StatusEnum.not_ok diff --git a/api_app/tests_ma/test_db/test_events.py b/api_app/tests_ma/test_db/test_events.py index 016067b105..009259f983 100644 --- a/api_app/tests_ma/test_db/test_events.py +++ b/api_app/tests_ma/test_db/test_events.py @@ -8,8 +8,8 @@ @patch("db.events.get_credential") @patch("db.events.CosmosDBManagementClient") -async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_cm_mock): - get_credential_async_cm_mock.return_value = AsyncMock() +async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_credential_async_context_mock): + get_credential_async_context_mock.return_value = AsyncMock() cosmos_db_mgmt_client_mock.return_value = MagicMock() result = await events.bootstrap_database() @@ -19,8 +19,8 @@ async def test_bootstrap_database_success(cosmos_db_mgmt_client_mock, get_creden @patch("db.events.get_credential") @patch("db.events.CosmosDBManagementClient") -async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_cm_mock): - get_credential_async_cm_mock.return_value = AsyncMock() +async def test_bootstrap_database_failure(cosmos_db_mgmt_client_mock, get_credential_async_context_mock): + get_credential_async_context_mock.return_value = AsyncMock() cosmos_db_mgmt_client_mock.side_effect = AzureError("some error") result = await events.bootstrap_database() diff --git a/api_app/tests_ma/test_services/test_health_checker.py b/api_app/tests_ma/test_services/test_health_checker.py index 4d053d7e80..6b1d955f89 100644 --- a/api_app/tests_ma/test_services/test_health_checker.py +++ b/api_app/tests_ma/test_services/test_health_checker.py @@ -11,9 +11,8 @@ pytestmark = pytest.mark.asyncio -@patch("api.dependencies.database.Database.get_container_proxy", return_value=AsyncMock()) +@patch("azure.cosmos.aio.ContainerProxy.query_items", return_value=AsyncMock()) async def test_get_state_store_status_responding(_) -> None: - # get_store_key_mock.return_value = None status, message = await health_checker.create_state_store_status() assert status == StatusEnum.ok @@ -40,45 +39,45 @@ async def test_get_state_store_status_other_exception(container_proxy_mock) -> N assert message == strings.UNSPECIFIED_ERROR -@patch("core.credentials.get_credential_async_cm") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_responding(service_bus_client_mock, get_credential_async_cm) -> None: - get_credential_async_cm.return_value = AsyncMock() +async def test_get_service_bus_status_responding(service_bus_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() service_bus_client_mock().get_queue_receiver.__aenter__.return_value = AsyncMock() - status, message = await health_checker.create_service_bus_status(get_credential_async_cm) + status, message = await health_checker.create_service_bus_status(get_credential_async_context) assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async_cm") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_not_responding(service_bus_client_mock, get_credential_async_cm) -> None: - get_credential_async_cm.return_value = AsyncMock() +async def test_get_service_bus_status_not_responding(service_bus_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() service_bus_client_mock.return_value = None service_bus_client_mock.side_effect = ServiceBusConnectionError(message="some message") - status, message = await health_checker.create_service_bus_status(get_credential_async_cm) + status, message = await health_checker.create_service_bus_status(get_credential_async_context) assert status == StatusEnum.not_ok assert message == strings.SERVICE_BUS_NOT_RESPONDING -@patch("core.credentials.get_credential_async_cm") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ServiceBusClient") -async def test_get_service_bus_status_other_exception(service_bus_client_mock, get_credential_async_cm) -> None: - get_credential_async_cm.return_value = AsyncMock() +async def test_get_service_bus_status_other_exception(service_bus_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() service_bus_client_mock.return_value = None service_bus_client_mock.side_effect = Exception() - status, message = await health_checker.create_service_bus_status(get_credential_async_cm) + status, message = await health_checker.create_service_bus_status(get_credential_async_context) assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR -@patch("core.credentials.get_credential_async_cm") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ComputeManagementClient") -async def test_get_resource_processor_status_healthy(resource_processor_client_mock, get_credential_async_cm) -> None: - get_credential_async_cm.return_value = AsyncMock() +async def test_get_resource_processor_status_healthy(resource_processor_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() resource_processor_client_mock().virtual_machine_scale_set_vms.return_value = AsyncMock() vm_mock = MagicMock() vm_mock.instance_id = 'mocked_id' @@ -90,16 +89,16 @@ async def test_get_resource_processor_status_healthy(resource_processor_client_m awaited_mock.set_result(instance_view_mock) resource_processor_client_mock().virtual_machine_scale_set_vms.get_instance_view.return_value = awaited_mock - status, message = await health_checker.create_resource_processor_status(get_credential_async_cm) + status, message = await health_checker.create_resource_processor_status(get_credential_async_context) assert status == StatusEnum.ok assert message == "" -@patch("core.credentials.get_credential_async_cm") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ComputeManagementClient", return_value=MagicMock()) -async def test_get_resource_processor_status_not_healthy(resource_processor_client_mock, get_credential_async_cm) -> None: - get_credential_async_cm.return_value = AsyncMock() +async def test_get_resource_processor_status_not_healthy(resource_processor_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() resource_processor_client_mock().virtual_machine_scale_set_vms.return_value = AsyncMock() vm_mock = MagicMock() @@ -112,19 +111,19 @@ async def test_get_resource_processor_status_not_healthy(resource_processor_clie awaited_mock.set_result(instance_view_mock) resource_processor_client_mock().virtual_machine_scale_set_vms.get_instance_view.return_value = awaited_mock - status, message = await health_checker.create_resource_processor_status(get_credential_async_cm) + status, message = await health_checker.create_resource_processor_status(get_credential_async_context) assert status == StatusEnum.not_ok assert message == strings.RESOURCE_PROCESSOR_GENERAL_ERROR_MESSAGE -@patch("core.credentials.get_credential_async_cm") +@patch("core.credentials.get_credential_async_context") @patch("services.health_checker.ComputeManagementClient") -async def test_get_resource_processor_status_other_exception(resource_processor_client_mock, get_credential_async_cm) -> None: - get_credential_async_cm.return_value = AsyncMock() +async def test_get_resource_processor_status_other_exception(resource_processor_client_mock, get_credential_async_context) -> None: + get_credential_async_context.return_value = AsyncMock() resource_processor_client_mock.return_value = None resource_processor_client_mock.side_effect = Exception() - status, message = await health_checker.create_resource_processor_status(get_credential_async_cm) + status, message = await health_checker.create_resource_processor_status(get_credential_async_context) assert status == StatusEnum.not_ok assert message == strings.UNSPECIFIED_ERROR From 90f75574494589dcaea27abb77af7d6abd6b3db8 Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 3 Jan 2024 16:07:02 +0000 Subject: [PATCH 29/32] fix linting --- api_app/_version.py | 2 +- api_app/core/credentials.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/api_app/_version.py b/api_app/_version.py index 782e3ece69..1317d7554a 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.18.8" +__version__ = "0.18.0" diff --git a/api_app/core/credentials.py b/api_app/core/credentials.py index b05ad1bf17..c752052ac9 100644 --- a/api_app/core/credentials.py +++ b/api_app/core/credentials.py @@ -29,6 +29,7 @@ def get_credential() -> TokenCredential: exclude_powershell_credential=True ) + async def get_credential_async(managed_identity): return ( ChainedTokenCredentialASync( @@ -44,6 +45,7 @@ async def get_credential_async(managed_identity): ) ) + @asynccontextmanager async def get_credential_async_context() -> TokenCredential: """ From 18b651da30ac08c0c79b429ae9221d0dfc6d53db Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 3 Jan 2024 16:58:52 +0000 Subject: [PATCH 30/32] fix msi handling --- api_app/core/credentials.py | 8 +++--- .../tests_ma/test_core/test_credentials.py | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) create mode 100644 api_app/tests_ma/test_core/test_credentials.py diff --git a/api_app/core/credentials.py b/api_app/core/credentials.py index c752052ac9..427f62b529 100644 --- a/api_app/core/credentials.py +++ b/api_app/core/credentials.py @@ -30,12 +30,12 @@ def get_credential() -> TokenCredential: ) -async def get_credential_async(managed_identity): +async def get_credential_async(): return ( ChainedTokenCredentialASync( - ManagedIdentityCredentialASync(client_id=managed_identity) + ManagedIdentityCredentialASync(client_id=MANAGED_IDENTITY_CLIENT_ID) ) - if managed_identity + if MANAGED_IDENTITY_CLIENT_ID else DefaultAzureCredentialASync(authority=urlparse(AAD_AUTHORITY_URL).netloc, exclude_shared_token_cache_credential=True, exclude_workload_identity_credential=True, @@ -51,6 +51,6 @@ async def get_credential_async_context() -> TokenCredential: """ Context manager which yields the default credentials. """ - credential = await get_credential_async(MANAGED_IDENTITY_CLIENT_ID) + credential = await get_credential_async() yield credential await credential.close() diff --git a/api_app/tests_ma/test_core/test_credentials.py b/api_app/tests_ma/test_core/test_credentials.py new file mode 100644 index 0000000000..3140966544 --- /dev/null +++ b/api_app/tests_ma/test_core/test_credentials.py @@ -0,0 +1,26 @@ +from unittest.mock import MagicMock, patch +from urllib.parse import urlparse +import pytest + +from azure.identity.aio import ( + DefaultAzureCredential as DefaultAzureCredentialASync, + ManagedIdentityCredential as ManagedIdentityCredentialASync, + ChainedTokenCredential as ChainedTokenCredentialASync, +) + +from core.credentials import get_credential_async + +pytestmark = pytest.mark.asyncio + + +@patch("core.credentials.MANAGED_IDENTITY_CLIENT_ID", "mocked_client_id") +async def test_get_credential_async_with_managed_identity_client_id(): + credential = await get_credential_async() + + assert isinstance(credential.credentials[0], ManagedIdentityCredentialASync) + + +async def test_get_credential_async_without_managed_identity_client_id(): + credential = await get_credential_async() + + assert isinstance(credential, DefaultAzureCredentialASync) From 2e95926621d0eaf5670b84a34523cc5c2ee782d0 Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 3 Jan 2024 17:00:03 +0000 Subject: [PATCH 31/32] missed file. --- api_app/tests_ma/test_core/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 api_app/tests_ma/test_core/__init__.py diff --git a/api_app/tests_ma/test_core/__init__.py b/api_app/tests_ma/test_core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From c33d9db1fd79935acff3bb372ccc39b931354f06 Mon Sep 17 00:00:00 2001 From: marrobi Date: Wed, 3 Jan 2024 17:11:38 +0000 Subject: [PATCH 32/32] fix linting --- api_app/tests_ma/test_core/test_credentials.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/api_app/tests_ma/test_core/test_credentials.py b/api_app/tests_ma/test_core/test_credentials.py index 3140966544..d99e360be9 100644 --- a/api_app/tests_ma/test_core/test_credentials.py +++ b/api_app/tests_ma/test_core/test_credentials.py @@ -1,11 +1,9 @@ -from unittest.mock import MagicMock, patch -from urllib.parse import urlparse +from mock import patch import pytest from azure.identity.aio import ( DefaultAzureCredential as DefaultAzureCredentialASync, - ManagedIdentityCredential as ManagedIdentityCredentialASync, - ChainedTokenCredential as ChainedTokenCredentialASync, + ManagedIdentityCredential as ManagedIdentityCredentialASync ) from core.credentials import get_credential_async