Skip to content

Commit

Permalink
Merge pull request #12 from BradLewis/develop
Browse files Browse the repository at this point in the history
Adding optional arguments and better error messaging.
  • Loading branch information
BradLewis authored Oct 21, 2020
2 parents a6c6081 + 959022b commit e8d0963
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 9 deletions.
2 changes: 1 addition & 1 deletion examples/flask/flask_inject_create_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, db_service: IDatabaseService):

def get_by_id(self, id: int):
query = "SELECT * FROM users WHERE id = ?"
result = self._db_serive.run_query(query, id)
result = self._db_service.run_query(query, id)
return {"name": "test name", "id": id}


Expand Down
52 changes: 44 additions & 8 deletions simple_injection/service_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def __init__(
self.implementations: List[Union[Type[T], T]] = list()


class ServiceResolutionError(Exception):
def __init__(self, message, errors, service):
super().__init__(message, errors)
self.service = service


class ServiceCollection:
def __init__(self):
self._service_collection: Dict[Type[T], _ContainerService] = dict()
Expand Down Expand Up @@ -200,14 +206,29 @@ def resolve(self, service_to_resolve: Type[T]) -> T:
Returns:
T: An instance of the resolved service.
"""
container_service = self._service_collection[service_to_resolve]
if container_service.multiple_implementations:
return self._resolve_multiple(container_service)
if container_service.service_lifetime == ServiceLifetime.INSTANCE:
return self._resolve_instance(container_service)
if container_service.service_lifetime == ServiceLifetime.SINGLETON:
return self._resolve_singleton(container_service)
return self._resolve_annotations(container_service)
if self._is_optional(service_to_resolve):
return self._handle_optional(service_to_resolve)
if service_to_resolve not in self._service_collection:
raise ValueError(
f"Service {service_to_resolve} not found in the collection. Ensure {service_to_resolve} has been added to the collection."
)
try:
container_service = self._service_collection[service_to_resolve]
if container_service.multiple_implementations:
return self._resolve_multiple(container_service)
if container_service.service_lifetime == ServiceLifetime.INSTANCE:
return self._resolve_instance(container_service)
if container_service.service_lifetime == ServiceLifetime.SINGLETON:
return self._resolve_singleton(container_service)
return self._resolve_annotations(container_service)
except ServiceResolutionError as e:
raise e
except Exception as e:
raise ServiceResolutionError(
f"Exception occurred trying to resolve {service_to_resolve}",
e,
service_to_resolve,
)

def _resolve_multiple(self, container_service: _ContainerService):
services_to_resolve = container_service.implementations
Expand Down Expand Up @@ -249,6 +270,13 @@ def _resolve_args(self, container_service: _ContainerService):
args.append(arg)
return container_service.service_implementation(*args)

def _handle_optional(self, service_to_resolve: Type[T]):
to_resolve = service_to_resolve.__args__[0]
if to_resolve in self._service_collection:
return self.resolve(to_resolve)
else:
return None

def _create_list_service(self, service_to_add):
service = self._service_collection[service_to_add]
self.add(
Expand All @@ -262,3 +290,11 @@ def _create_list_service(self, service_to_add):
self._service_collection[List[service_to_add]].implementations.append(
service.service_implementation
)

def _is_optional(self, service_to_resolve: Type[T]):
if not hasattr(service_to_resolve, "__args__"):
return False
args = service_to_resolve.__args__
if len(args) != 2:
return False
return args[-1] == type(None)
26 changes: 26 additions & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from simple_injection.service_collection import ServiceResolutionError
import pytest
from tests.classes import A, B, C
from simple_injection import ServiceCollection


class D:
def __init__(self, b: B):
self._b = b


def test_resolution_error():
collection = ServiceCollection()
collection.add_transient(B)

with pytest.raises(ServiceResolutionError):
collection.resolve(B)


def test_nested_resolution_error():
collection = ServiceCollection()
collection.add_transient(D)
collection.add_transient(B)

with pytest.raises(ServiceResolutionError):
collection.resolve(D)
48 changes: 48 additions & 0 deletions tests/test_optional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from simple_injection import ServiceCollection
from typing import Optional
from tests.classes import A, B, C


class HasOptional:
def __init__(self, a: A, b: Optional[B] = None):
self._a = a
self._b = b


class HasOptionalWithDeps:
def __init__(self, a: A, c: Optional[C] = None):
self._a = a
self._c = c


def test_optional_declared():
collection = ServiceCollection()
collection.add_transient(A)
collection.add_transient(B)
collection.add_transient(HasOptional)

has_optional = collection.resolve(HasOptional)
assert isinstance(has_optional._a, A)
assert isinstance(has_optional._b, B)


def test_optional_not_declared():
collection = ServiceCollection()
collection.add_transient(A)
collection.add_transient(HasOptional)

has_optional = collection.resolve(HasOptional)
assert isinstance(has_optional._a, A)
assert has_optional._b is None


def test_optional_with_dependencies():
collection = ServiceCollection()
collection.add_transient(A)
collection.add_transient(B)
collection.add_transient(C)
collection.add_transient(HasOptionalWithDeps)

has_optional_with_deps = collection.resolve(HasOptionalWithDeps)
assert isinstance(has_optional_with_deps._a, A)
assert isinstance(has_optional_with_deps._c, C)

0 comments on commit e8d0963

Please sign in to comment.