diff --git a/examples/flask/flask_inject_create_app.py b/examples/flask/flask_inject_create_app.py index 24cee59..e053bb0 100644 --- a/examples/flask/flask_inject_create_app.py +++ b/examples/flask/flask_inject_create_app.py @@ -27,7 +27,7 @@ def __init__(self, db_service: IDatabaseService): def get_by_id(self, id: int): query = "SELECT * FROM users WHERE id = ?" - result = self._db_serive.run_query(query, id) + result = self._db_service.run_query(query, id) return {"name": "test name", "id": id} diff --git a/simple_injection/service_collection.py b/simple_injection/service_collection.py index d4f3249..32f99e9 100644 --- a/simple_injection/service_collection.py +++ b/simple_injection/service_collection.py @@ -59,6 +59,12 @@ def __init__( self.implementations: List[Union[Type[T], T]] = list() +class ServiceResolutionError(Exception): + def __init__(self, message, errors, service): + super().__init__(message, errors) + self.service = service + + class ServiceCollection: def __init__(self): self._service_collection: Dict[Type[T], _ContainerService] = dict() @@ -200,14 +206,29 @@ def resolve(self, service_to_resolve: Type[T]) -> T: Returns: T: An instance of the resolved service. """ - container_service = self._service_collection[service_to_resolve] - if container_service.multiple_implementations: - return self._resolve_multiple(container_service) - if container_service.service_lifetime == ServiceLifetime.INSTANCE: - return self._resolve_instance(container_service) - if container_service.service_lifetime == ServiceLifetime.SINGLETON: - return self._resolve_singleton(container_service) - return self._resolve_annotations(container_service) + if self._is_optional(service_to_resolve): + return self._handle_optional(service_to_resolve) + if service_to_resolve not in self._service_collection: + raise ValueError( + f"Service {service_to_resolve} not found in the collection. Ensure {service_to_resolve} has been added to the collection." + ) + try: + container_service = self._service_collection[service_to_resolve] + if container_service.multiple_implementations: + return self._resolve_multiple(container_service) + if container_service.service_lifetime == ServiceLifetime.INSTANCE: + return self._resolve_instance(container_service) + if container_service.service_lifetime == ServiceLifetime.SINGLETON: + return self._resolve_singleton(container_service) + return self._resolve_annotations(container_service) + except ServiceResolutionError as e: + raise e + except Exception as e: + raise ServiceResolutionError( + f"Exception occurred trying to resolve {service_to_resolve}", + e, + service_to_resolve, + ) def _resolve_multiple(self, container_service: _ContainerService): services_to_resolve = container_service.implementations @@ -249,6 +270,13 @@ def _resolve_args(self, container_service: _ContainerService): args.append(arg) return container_service.service_implementation(*args) + def _handle_optional(self, service_to_resolve: Type[T]): + to_resolve = service_to_resolve.__args__[0] + if to_resolve in self._service_collection: + return self.resolve(to_resolve) + else: + return None + def _create_list_service(self, service_to_add): service = self._service_collection[service_to_add] self.add( @@ -262,3 +290,11 @@ def _create_list_service(self, service_to_add): self._service_collection[List[service_to_add]].implementations.append( service.service_implementation ) + + def _is_optional(self, service_to_resolve: Type[T]): + if not hasattr(service_to_resolve, "__args__"): + return False + args = service_to_resolve.__args__ + if len(args) != 2: + return False + return args[-1] == type(None) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000..6be4d3f --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,26 @@ +from simple_injection.service_collection import ServiceResolutionError +import pytest +from tests.classes import A, B, C +from simple_injection import ServiceCollection + + +class D: + def __init__(self, b: B): + self._b = b + + +def test_resolution_error(): + collection = ServiceCollection() + collection.add_transient(B) + + with pytest.raises(ServiceResolutionError): + collection.resolve(B) + + +def test_nested_resolution_error(): + collection = ServiceCollection() + collection.add_transient(D) + collection.add_transient(B) + + with pytest.raises(ServiceResolutionError): + collection.resolve(D) diff --git a/tests/test_optional.py b/tests/test_optional.py new file mode 100644 index 0000000..0026358 --- /dev/null +++ b/tests/test_optional.py @@ -0,0 +1,48 @@ +from simple_injection import ServiceCollection +from typing import Optional +from tests.classes import A, B, C + + +class HasOptional: + def __init__(self, a: A, b: Optional[B] = None): + self._a = a + self._b = b + + +class HasOptionalWithDeps: + def __init__(self, a: A, c: Optional[C] = None): + self._a = a + self._c = c + + +def test_optional_declared(): + collection = ServiceCollection() + collection.add_transient(A) + collection.add_transient(B) + collection.add_transient(HasOptional) + + has_optional = collection.resolve(HasOptional) + assert isinstance(has_optional._a, A) + assert isinstance(has_optional._b, B) + + +def test_optional_not_declared(): + collection = ServiceCollection() + collection.add_transient(A) + collection.add_transient(HasOptional) + + has_optional = collection.resolve(HasOptional) + assert isinstance(has_optional._a, A) + assert has_optional._b is None + + +def test_optional_with_dependencies(): + collection = ServiceCollection() + collection.add_transient(A) + collection.add_transient(B) + collection.add_transient(C) + collection.add_transient(HasOptionalWithDeps) + + has_optional_with_deps = collection.resolve(HasOptionalWithDeps) + assert isinstance(has_optional_with_deps._a, A) + assert isinstance(has_optional_with_deps._c, C)