Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the option to resolve dependencies by name #39

Merged
merged 4 commits into from
Apr 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions DIKit/Sources/Component/Component.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,51 @@ public typealias ComponentFactory = () -> Any

class Component<T>: ComponentProtocol {
let lifetime: Lifetime
let tag: String
let identifier: AnyHashable
let type: Any.Type
let componentFactory: ComponentFactory

init(lifetime: Lifetime, type: T.Type, factory: @escaping () -> T) {
init(lifetime: Lifetime, factory: @escaping () -> T) {
self.lifetime = lifetime
self.tag = String(describing: type)
self.type = type
self.identifier = ComponentIdentifier(type: T.self)
self.type = T.self
self.componentFactory = { factory() }
}

init(lifetime: Lifetime, tag: AnyHashable, factory: @escaping () -> T) {
self.lifetime = lifetime
self.identifier = ComponentIdentifier(tag: tag, type: T.self)
self.type = T.self
self.componentFactory = { factory() }
}
}

struct ComponentIdentifier: Hashable {
let tag: AnyHashable?
let type: Any.Type

func hash(into hasher: inout Hasher) {
hasher.combine(String(describing: type))
if let tag = tag {
hasher.combine(tag)
}
}

static func == (lhs: ComponentIdentifier, rhs: ComponentIdentifier) -> Bool {
lhs.type == rhs.type && lhs.tag == rhs.tag
}
}

extension ComponentIdentifier {
init(type: Any.Type) {
self.type = type
self.tag = nil
}
}

public protocol ComponentProtocol {
var lifetime: Lifetime { get }
var tag: String { get }
var identifier: AnyHashable { get }
var componentFactory: ComponentFactory { get }
var type: Any.Type { get }
}
27 changes: 16 additions & 11 deletions DIKit/Sources/Container/DependencyContainer+Register.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,31 @@ extension DependencyContainer {
/// Registers a `Component`.
///
/// - Parameters:
/// - scope: The *scope* of the `Component`, defaults to `Lifetime.singleton`.
/// - lifetime: The *scope* of the `Component`, defaults to `Lifetime.singleton`.
/// - factory: The *factory* for the initialization of the `Component`.
public func register<T>(lifetime: Lifetime = .singleton, _ factory: @escaping () -> T) {
precondition(!bootstrapped, "After boostrap no more components can be registered.")
threadSafe {
let component = Component(lifetime: lifetime, type: T.self, factory: factory)
guard self.componentStack[component.tag] == nil else {
fatalError("A component can only be registered once.")
}
self.componentStack[component.tag] = component
}
let component = Component(lifetime: lifetime, factory: factory)
register(component)
}

/// Registers a `Component`
///
/// - Parameters:
/// - lifetime: The *scope* of the `Component`, defaults to `Lifetime.singleton`.
/// - tag: A *tag* for the `Component` used to identify it.
/// - factory: The *factory* for the initialization of the `Component`.
public func register<T>(lifetime: Lifetime = .singleton, tag: AnyHashable, _ factory: @escaping () -> T) {
let component = Component(lifetime: lifetime, tag: tag, factory: factory)
register(component)
}

public func register(_ component: ComponentProtocol) {
precondition(!bootstrapped, "After boostrap no more components can be registered.")
threadSafe {
guard self.componentStack[component.tag] == nil else {
guard self.componentStack[component.identifier] == nil else {
fatalError("A component can only be registered once.")
}
self.componentStack[component.tag] = component
self.componentStack[component.identifier] = component
}
}
}
24 changes: 14 additions & 10 deletions DIKit/Sources/Container/DependencyContainer+Resolve.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,21 @@
extension DependencyContainer {
/// Resolves nil safe a `Component<T>`.
///
/// - Parameter tag: An optional *tag* to identify the Component. `nil` per default.
/// - Returns: The resolved `Optional<Component<T>>`.
func _resolve<T>() -> T? {
let tag = String(describing: T.self)
guard let foundComponent = self.componentStack[tag] else {
func _resolve<T>(tag: AnyHashable? = nil) -> T? {
let identifier = ComponentIdentifier(tag: tag, type: T.self)
guard let foundComponent = self.componentStack[identifier] else {
return nil
}
if foundComponent.lifetime == .factory {
return foundComponent.componentFactory() as? T
}
if let instanceOfComponent = self.instanceStack[tag] as? T {
if let instanceOfComponent = self.instanceStack[identifier] as? T {
return instanceOfComponent
}
let instance = foundComponent.componentFactory() as! T
self.instanceStack[tag] = instance
self.instanceStack[identifier] = instance
return instance
}

Expand All @@ -32,20 +33,23 @@ extension DependencyContainer {
///
/// - Parameters:
/// - type: The generic *type* of the `Component`.
/// - tag: An optional *tag* to identify the Component. `nil` per default.
///
/// - Returns: `Bool` whether `Component<T>` is resolvable or not.
func resolvable<T>(type: T.Type) -> Bool {
let tag = String(describing: type)
return self.componentStack[tag] != nil
func resolvable<T>(type: T.Type, tag: AnyHashable? = nil) -> Bool {
let identifier = ComponentIdentifier(tag: tag, type: T.self)
return self.componentStack[identifier] != nil
}

/// Resolves a `Component<T>`.
/// Implicitly assumes that the `Component` can be resolved.
/// Throws a fatalError if the `Component` is not registered.
///
/// - Parameter tag: An optional *tag* to identify the Component. `nil` per default.
///
/// - Returns: The resolved `Component<T>`.
public func resolve<T>() -> T {
if let t: T = _resolve() {
public func resolve<T>(tag: AnyHashable? = nil) -> T {
if let t: T = _resolve(tag: tag) {
return t
}
fatalError("Component `\(String(describing: T.self))` could not be resolved.")
Expand Down
4 changes: 2 additions & 2 deletions DIKit/Sources/Container/DependencyContainer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import Foundation
public final class DependencyContainer {
// MARK: - Typealiases
public typealias BootstrapBlock = (DependencyContainer) -> Void
internal typealias ComponentStack = [String: ComponentProtocol]
internal typealias InstanceStack = [String: Any]
internal typealias ComponentStack = [AnyHashable: ComponentProtocol]
benjohnde marked this conversation as resolved.
Show resolved Hide resolved
internal typealias InstanceStack = [AnyHashable: Any]

// MARK: - Properties
internal var bootstrapped = false
Expand Down
12 changes: 12 additions & 0 deletions DIKit/Sources/DIKit+Inject.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ public enum LazyInject<Component> {
self = .unresolved({ resolve() })
}

public init(tag: AnyHashable? = nil) {
self = .unresolved({ resolve(tag: tag) })
}

public var wrappedValue: Component {
mutating get {
switch self {
Expand All @@ -41,6 +45,10 @@ public struct Inject<Component> {
public init() {
self.wrappedValue = resolve()
}

public init(tag: AnyHashable? = nil) {
self.wrappedValue = resolve(tag: tag)
}
}

/// A property wrapper (SE-0258) to make a `Optional<Component>` injectable
Expand All @@ -54,6 +62,10 @@ public enum OptionalInject<Component> {
self = .unresolved({ resolveOptional() })
}

public init(tag: AnyHashable? = nil) {
self = .unresolved({ resolveOptional(tag: tag) })
}

public var wrappedValue: Component? {
mutating get {
switch self {
Expand Down
10 changes: 6 additions & 4 deletions DIKit/Sources/DIKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@

/// Resolves given `Component<T>`.
///
/// - Parameter tag: An optional *tag* to identify the Component. `nil` per default.
/// - Returns: The resolved `Component<T>`.
public func resolve<T>() -> T { DependencyContainer.shared.resolve() }
public func resolve<T>(tag: AnyHashable? = nil) -> T { DependencyContainer.shared.resolve(tag: tag) }

/// Resolves nil safe given `Component<T>`.
///
/// - Parameter tag: An optional *tag* to identify the Component. `nil` per default.
/// - Returns: The resolved `Optional<Component<T>>`.
public func resolveOptional<T>() -> T? {
guard DependencyContainer.shared.resolvable(type: T.self) else { return nil }
return DependencyContainer.shared._resolve()
public func resolveOptional<T>(tag: AnyHashable? = nil) -> T? {
guard DependencyContainer.shared.resolvable(type: T.self, tag: tag) else { return nil }
return DependencyContainer.shared._resolve(tag: tag)
}
14 changes: 13 additions & 1 deletion DIKit/Sources/DIKitDSL.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,25 @@ public func modules(@ModulesBuilder makeChildren: () -> [DependencyContainer]) -
}

public func resolvable<T>(lifetime: Lifetime = .singleton, _ factory: @escaping () -> T) -> ComponentProtocol {
Component(lifetime: lifetime, type: T.self, factory: factory) as ComponentProtocol
Component(lifetime: lifetime, factory: factory) as ComponentProtocol
}

public func resolvable<T>(lifetime: Lifetime = .singleton, tag: AnyHashable, _ factory: @escaping () -> T) -> ComponentProtocol {
Component(lifetime: lifetime, tag: tag, factory: factory) as ComponentProtocol
}

public func factory<T>(factory: @escaping () -> T) -> [ComponentProtocol] {
[resolvable(lifetime: .factory, factory)]
}

public func factory<T>(tag: AnyHashable, factory: @escaping () -> T) -> [ComponentProtocol] {
[resolvable(lifetime: .factory, tag: tag, factory)]
}

public func single<T>(factory: @escaping () -> T) -> [ComponentProtocol] {
[resolvable(lifetime: .singleton, factory)]
}

public func single<T>(tag: AnyHashable, factory: @escaping () -> T) -> [ComponentProtocol] {
[resolvable(lifetime: .singleton, tag: tag, factory)]
}
60 changes: 51 additions & 9 deletions DIKit/Tests/DIKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ class DIKitTests: XCTestCase {
let dependencyContainer = DependencyContainer { (c: DependencyContainer) in
c.register { ComponentA() }
c.register { ComponentB() }
c.register(tag: "tag") { ComponentB() }
}

guard let componentA = dependencyContainer.componentStack.index(forKey: "ComponentA") else {
let componentAIdentifier = ComponentIdentifier(type: ComponentA.self)
guard let componentA = dependencyContainer.componentStack.index(forKey: componentAIdentifier) else {
return XCTFail("ComponentStack does not contain `ComponentA`.")
}
let componentProtocolA = dependencyContainer.componentStack[componentA].value
Expand All @@ -37,14 +39,25 @@ class DIKitTests: XCTestCase {
XCTAssertTrue(instanceA is ComponentA)
XCTAssertFalse(instanceA is ComponentB)

guard let componentB = dependencyContainer.componentStack.index(forKey: "ComponentB") else {
let componentBIdentifier = ComponentIdentifier(type: ComponentB.self)
guard let componentB = dependencyContainer.componentStack.index(forKey: componentBIdentifier) else {
return XCTFail("ComponentStack does not contain `ComponentB`.")
}
let componentProtocolB = dependencyContainer.componentStack[componentB].value
XCTAssertEqual(componentProtocolB.lifetime, .singleton)
let instanceB = componentProtocolB.componentFactory()
XCTAssertTrue(instanceB is ComponentB)
XCTAssertFalse(instanceB is ComponentA)

let taggedComponentBIdentifier = ComponentIdentifier(tag: "tag", type: ComponentB.self)
guard let taggedComponentB = dependencyContainer.componentStack.index(forKey: taggedComponentBIdentifier) else {
return XCTFail("ComponentStack does not contain `ComponentB`.")
}
let taggedComponentProtocolB = dependencyContainer.componentStack[taggedComponentB].value
XCTAssertEqual(taggedComponentProtocolB.lifetime, .singleton)
let taggedInstanceB = taggedComponentProtocolB.componentFactory()
XCTAssertTrue(taggedInstanceB is ComponentB)
XCTAssertFalse(taggedInstanceB is ComponentA)
}

func testDependencyContainerResolve() {
Expand All @@ -58,6 +71,17 @@ class DIKitTests: XCTestCase {
XCTAssertNotNil(componentA)
}

func testDependencyContainerTaggedResolve() {
class ComponentA {}

let dependencyContainer = DependencyContainer { (c: DependencyContainer) in
c.register(tag: "tag") { ComponentA() }
}

let componentA: ComponentA = dependencyContainer.resolve(tag: "tag")
XCTAssertNotNil(componentA)
}

func testDependencyContainerDerive() {
struct ComponentA {}
struct ComponentB {}
Expand All @@ -72,10 +96,14 @@ class DIKitTests: XCTestCase {
let dependencyContainerC = DependencyContainer { (c: DependencyContainer) in
c.register { ComponentC() }
}
let dependencyContainerD = DependencyContainer { (c: DependencyContainer) in
c.register(tag: "tag") { ComponentC() }
}

let dependencyContainer = DependencyContainer.derive(from: dependencyContainerA,
dependencyContainerB,
dependencyContainerC)
dependencyContainerC,
dependencyContainerD)

let componentA: ComponentA = dependencyContainer.resolve()
XCTAssertNotNil(componentA)
Expand All @@ -85,6 +113,9 @@ class DIKitTests: XCTestCase {

let componentC: ComponentC = dependencyContainer.resolve()
XCTAssertNotNil(componentC)

let taggedComponentC: ComponentC = dependencyContainer.resolve(tag: "tag")
XCTAssertNotNil(taggedComponentC)
}

func testDependencyContainerDeriveDSL() {
Expand All @@ -101,8 +132,11 @@ class DIKitTests: XCTestCase {
let dependencyContainerC = module {
single { ComponentC() }
}
let dependencyContainerD = module {
single(tag: "tag") { ComponentC() }
}

let dependencyContainer = modules { dependencyContainerA; dependencyContainerB; dependencyContainerC }
let dependencyContainer = modules { dependencyContainerA; dependencyContainerB; dependencyContainerC; dependencyContainerD }

let componentA: ComponentA = dependencyContainer.resolve()
XCTAssertNotNil(componentA)
Expand All @@ -112,6 +146,9 @@ class DIKitTests: XCTestCase {

let componentC: ComponentC = dependencyContainer.resolve()
XCTAssertNotNil(componentC)

let taggedComponentC: ComponentC = dependencyContainer.resolve(tag: "tag")
XCTAssertNotNil(taggedComponentC)
}

func testFactoryOfComponents() {
Expand Down Expand Up @@ -350,6 +387,7 @@ class DIKitTests: XCTestCase {

let dependencyContainer = DependencyContainer { (c: DependencyContainer) in
c.register(lifetime: .singleton) { ComponentA() }
c.register(lifetime: .singleton, tag: "tag") { ComponentA() }
}

let componentAinstanceA: ComponentA = dependencyContainer.resolve()
Expand All @@ -358,12 +396,16 @@ class DIKitTests: XCTestCase {
let componentAinstanceB: ComponentA = dependencyContainer.resolve()
XCTAssertNotNil(componentAinstanceB)

let componentAinstanceAobjectIdA = ObjectIdentifier(componentAinstanceA)
let componentAinstanceAobjectIdB = ObjectIdentifier(componentAinstanceA)
let componentAinstanceBobjectId = ObjectIdentifier(componentAinstanceB)
let taggedComponentAinstanceA: ComponentA = dependencyContainer.resolve(tag: "tag")
XCTAssertNotNil(taggedComponentAinstanceA)

XCTAssertEqual(componentAinstanceAobjectIdA, componentAinstanceAobjectIdB)
XCTAssertEqual(componentAinstanceAobjectIdA, componentAinstanceBobjectId)
let taggedComponentAinstanceB: ComponentA = dependencyContainer.resolve(tag: "tag")
XCTAssertNotNil(taggedComponentAinstanceB)

XCTAssertTrue(componentAinstanceA === componentAinstanceB)
XCTAssertTrue(taggedComponentAinstanceA === taggedComponentAinstanceB)
XCTAssertTrue(componentAinstanceA !== taggedComponentAinstanceA)
XCTAssertTrue(componentAinstanceA !== taggedComponentAinstanceB)
}

func testSingletonLifetimeOfComponentsDSL() {
Expand Down