diff --git a/lms/services/__init__.py b/lms/services/__init__.py index 98e9b9a47c..41813936ec 100644 --- a/lms/services/__init__.py +++ b/lms/services/__init__.py @@ -106,7 +106,9 @@ def includeme(config): # noqa: PLR0915 config.register_service_factory( "lms.services.grouping.service_factory", name="grouping" ) - config.register_service_factory("lms.services.file.factory", name="file") + config.register_service_factory( + "lms.services.file.file_service_factory", name="file" + ) config.register_service_factory( "lms.services.jstor.service_factory", iface=JSTORService ) diff --git a/lms/services/canvas_api/factory.py b/lms/services/canvas_api/factory.py index 129d7eaa9d..97ede57a82 100644 --- a/lms/services/canvas_api/factory.py +++ b/lms/services/canvas_api/factory.py @@ -3,16 +3,37 @@ from lms.services.canvas_api._basic import BasicClient from lms.services.canvas_api._pages import CanvasPagesClient from lms.services.canvas_api.client import CanvasAPIClient +from lms.services.file import file_service_factory +from lms.services.oauth2_token import oauth2_token_service_factory -def canvas_api_client_factory(_context, request): +def canvas_api_client_factory( + _context, request, application_instance=None, user_id=None +): """ Get a CanvasAPIClient from a pyramid request. :param request: Pyramid request object :return: An instance of CanvasAPIClient """ - application_instance = request.lti_user.application_instance + if application_instance and user_id: + oauth2_token_service = oauth2_token_service_factory( + _context, + request, + application_instance=application_instance, + user_id=user_id, + ) + file_service = file_service_factory(_context, request, application_instance) + + else: + oauth2_token_service = request.find_service(name="oauth2_token") + file_service = request.find_service(name="file") + + if not application_instance: + application_instance = request.lti_user.application_instance + + if not user_id: + user_id = request.lti_user.user_id developer_secret = application_instance.decrypted_developer_secret( request.find_service(AESService) @@ -22,13 +43,11 @@ def canvas_api_client_factory(_context, request): authenticated_api = AuthenticatedClient( basic_client=basic_client, - oauth2_token_service=request.find_service(name="oauth2_token"), + oauth2_token_service=oauth2_token_service, client_id=application_instance.developer_key, client_secret=developer_secret, redirect_uri=request.route_url("canvas_api.oauth.callback"), ) - file_service = request.find_service(name="file") - return CanvasAPIClient( authenticated_api, file_service=file_service, diff --git a/lms/services/file.py b/lms/services/file.py index dad6a8663a..fc89de2025 100644 --- a/lms/services/file.py +++ b/lms/services/file.py @@ -90,7 +90,8 @@ def _file_search_query( # noqa: PLR0913 return query -def factory(_context, request): - return FileService( - application_instance=request.lti_user.application_instance, db=request.db - ) +def file_service_factory(_context, request, application_instance=None): + if application_instance is None: + application_instance = request.lti_user.application_instance + + return FileService(application_instance=application_instance, db=request.db) diff --git a/tests/unit/lms/services/file_test.py b/tests/unit/lms/services/file_test.py index 463891fd0e..aa80147a1e 100644 --- a/tests/unit/lms/services/file_test.py +++ b/tests/unit/lms/services/file_test.py @@ -3,7 +3,7 @@ import pytest from lms.models import File -from lms.services.file import FileService, factory +from lms.services.file import FileService, file_service_factory from tests import factories @@ -140,7 +140,14 @@ def file(self, application_instance): @pytest.mark.usefixtures("application_instance_service") class TestFactory: - def test_it(self, pyramid_request): - file_service = factory(sentinel.context, pyramid_request) + def test_it(self, pyramid_request, FileService, application_instance, db_session): + file_service = file_service_factory(sentinel.context, pyramid_request) - assert isinstance(file_service, FileService) + FileService.assert_called_once_with( + application_instance=application_instance, db=db_session + ) + assert file_service == FileService.return_value + + @pytest.fixture + def FileService(self, patch): + return patch("lms.services.file.FileService")