From 5f72462574eb590725331aeda2a57f8aa408cfc8 Mon Sep 17 00:00:00 2001 From: marrobi Date: Tue, 24 Oct 2023 18:30:08 +0000 Subject: [PATCH] Add tests, change rsa library --- api_app/services/aad_authentication.py | 12 +++- .../test_services/test_aad_access_service.py | 62 ++++++++++++++++++- 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/api_app/services/aad_authentication.py b/api_app/services/aad_authentication.py index cffdd77259..61ff23f1fb 100644 --- a/api_app/services/aad_authentication.py +++ b/api_app/services/aad_authentication.py @@ -5,7 +5,6 @@ from typing import List, Optional import jwt import requests -import rsa from fastapi import Request, HTTPException, status from msal import ConfidentialClientApplication @@ -19,6 +18,10 @@ from api.dependencies.database import get_db_client_from_request from db.repositories.workspaces import WorkspaceRepository +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + MICROSOFT_GRAPH_URL = config.MICROSOFT_GRAPH_URL.strip("/") @@ -179,9 +182,12 @@ def _get_token_key(self, key_id: str) -> str: for key in keys['keys']: n = int.from_bytes(base64.urlsafe_b64decode(self._ensure_b64padding(key['n'])), "big") e = int.from_bytes(base64.urlsafe_b64decode(self._ensure_b64padding(key['e'])), "big") - pub_key = rsa.PublicKey(n, e) + pub_key = rsa.RSAPublicNumbers(e, n).public_key(default_backend()) # Cache the PEM formatted public key. - AzureADAuthorization._jwt_keys[key['kid']] = pub_key.save_pkcs1() + AzureADAuthorization._jwt_keys[key['kid']] = pub_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.PKCS1 + ) return AzureADAuthorization._jwt_keys[key_id] diff --git a/api_app/tests_ma/test_services/test_aad_access_service.py b/api_app/tests_ma/test_services/test_aad_access_service.py index 668de31562..eeb7754106 100644 --- a/api_app/tests_ma/test_services/test_aad_access_service.py +++ b/api_app/tests_ma/test_services/test_aad_access_service.py @@ -1,5 +1,5 @@ import pytest -from mock import patch +from mock import call, patch from models.domain.authentication import User, RoleAssignment from models.domain.workspace import Workspace, WorkspaceRole @@ -554,6 +554,66 @@ def test_get_workspace_role_assignment_details_with_groups_and_users_assigned_re assert "test_user1@email.com" in role_assignment_details["WorkspaceOwner"] +@patch("services.aad_authentication.AzureADAuthorization._get_auth_header") +@patch("services.aad_authentication.AzureADAuthorization._get_batch_users_by_role_assignments_body") +@patch("requests.post") +def test_get_user_emails_with_batch_of_more_than_20_requests(mock_graph_post, mock_get_batch_users_by_role_assignments_body, mock_headers): + # Arrange + access_service = AzureADAuthorization() + roles_graph_data = [{"id": "role1"}, {"id": "role2"}] + msgraph_token = "token" + batch_endpoint = access_service._get_batch_endpoint() + + # mock the response of _get_auth_header + headers = {"Authorization": f"Bearer {msgraph_token}"} + mock_headers.return_value = headers + headers["Content-type"] = "application/json" + + # mock the response of the get batch request for 30 users + batch_request_body_first_20 = { + "requests": [ + {"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(20) + ] + } + + batch_request_body_last_10 = { + "requests": [ + {"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(20, 30) + ] + } + + batch_request_body = { + "requests": [ + {"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(30) + ] + } + + mock_get_batch_users_by_role_assignments_body.return_value = batch_request_body + + # Mock the response of the post request + mock_graph_post_response = {"responses": [{"id": "user1"}, {"id": "user2"}]} + mock_graph_post.return_value.json.return_value = mock_graph_post_response + + # Act + users_graph_data = access_service._get_user_emails(roles_graph_data, msgraph_token) + + # Assert + assert len(users_graph_data["responses"]) == 4 + calls = [ + call( + f"{batch_endpoint}", + json=batch_request_body_first_20, + headers=headers + ), + call( + f"{batch_endpoint}", + json=batch_request_body_last_10, + headers=headers + ) + ] + mock_graph_post.assert_has_calls(calls, any_order=True) + + def get_mock_batch_response(user_principals, group_principals): response_body = {"responses": []} for user_principal in user_principals: