Skip to content

Commit

Permalink
Add tests, change rsa library
Browse files Browse the repository at this point in the history
  • Loading branch information
marrobi committed Oct 24, 2023
1 parent a3509c8 commit 5f72462
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 4 deletions.
12 changes: 9 additions & 3 deletions api_app/services/aad_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import List, Optional
import jwt
import requests
import rsa

from fastapi import Request, HTTPException, status
from msal import ConfidentialClientApplication
Expand All @@ -19,6 +18,10 @@
from api.dependencies.database import get_db_client_from_request
from db.repositories.workspaces import WorkspaceRepository

from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

MICROSOFT_GRAPH_URL = config.MICROSOFT_GRAPH_URL.strip("/")


Expand Down Expand Up @@ -179,9 +182,12 @@ def _get_token_key(self, key_id: str) -> str:
for key in keys['keys']:
n = int.from_bytes(base64.urlsafe_b64decode(self._ensure_b64padding(key['n'])), "big")
e = int.from_bytes(base64.urlsafe_b64decode(self._ensure_b64padding(key['e'])), "big")
pub_key = rsa.PublicKey(n, e)
pub_key = rsa.RSAPublicNumbers(e, n).public_key(default_backend())
# Cache the PEM formatted public key.
AzureADAuthorization._jwt_keys[key['kid']] = pub_key.save_pkcs1()
AzureADAuthorization._jwt_keys[key['kid']] = pub_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.PKCS1
)

return AzureADAuthorization._jwt_keys[key_id]

Expand Down
62 changes: 61 additions & 1 deletion api_app/tests_ma/test_services/test_aad_access_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from mock import patch
from mock import call, patch

from models.domain.authentication import User, RoleAssignment
from models.domain.workspace import Workspace, WorkspaceRole
Expand Down Expand Up @@ -554,6 +554,66 @@ def test_get_workspace_role_assignment_details_with_groups_and_users_assigned_re
assert "[email protected]" in role_assignment_details["WorkspaceOwner"]


@patch("services.aad_authentication.AzureADAuthorization._get_auth_header")
@patch("services.aad_authentication.AzureADAuthorization._get_batch_users_by_role_assignments_body")
@patch("requests.post")
def test_get_user_emails_with_batch_of_more_than_20_requests(mock_graph_post, mock_get_batch_users_by_role_assignments_body, mock_headers):
# Arrange
access_service = AzureADAuthorization()
roles_graph_data = [{"id": "role1"}, {"id": "role2"}]
msgraph_token = "token"
batch_endpoint = access_service._get_batch_endpoint()

# mock the response of _get_auth_header
headers = {"Authorization": f"Bearer {msgraph_token}"}
mock_headers.return_value = headers
headers["Content-type"] = "application/json"

# mock the response of the get batch request for 30 users
batch_request_body_first_20 = {
"requests": [
{"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(20)
]
}

batch_request_body_last_10 = {
"requests": [
{"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(20, 30)
]
}

batch_request_body = {
"requests": [
{"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(30)
]
}

mock_get_batch_users_by_role_assignments_body.return_value = batch_request_body

# Mock the response of the post request
mock_graph_post_response = {"responses": [{"id": "user1"}, {"id": "user2"}]}
mock_graph_post.return_value.json.return_value = mock_graph_post_response

# Act
users_graph_data = access_service._get_user_emails(roles_graph_data, msgraph_token)

# Assert
assert len(users_graph_data["responses"]) == 4
calls = [
call(
f"{batch_endpoint}",
json=batch_request_body_first_20,
headers=headers
),
call(
f"{batch_endpoint}",
json=batch_request_body_last_10,
headers=headers
)
]
mock_graph_post.assert_has_calls(calls, any_order=True)


def get_mock_batch_response(user_principals, group_principals):
response_body = {"responses": []}
for user_principal in user_principals:
Expand Down

0 comments on commit 5f72462

Please sign in to comment.