diff --git a/README.md b/README.md index 1d14727..5bd0ee1 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ shared/exclusive mode semantics. `MultiLock.locked` can be in one of three states: -1. `None` - not locked; +1. `MultiLockType.NONE` - not acquired; 2. `MultiLockType.SHARED` - acquired one or more times in shared mode; 3. `MultiLockType.EXCLUSIVE` - acquired one time in exclusive mode. diff --git a/src/asyncio_multilock/__init__.py b/src/asyncio_multilock/__init__.py index fa5cab6..665fd3f 100644 --- a/src/asyncio_multilock/__init__.py +++ b/src/asyncio_multilock/__init__.py @@ -12,8 +12,6 @@ Optional, ) -__sentinel__ = object() - class LockError(Exception): ... @@ -25,7 +23,7 @@ class HandleUsedError(LockError): ... class MultiLockType(IntEnum): - # 0 not used so that type is always truthy. + NONE = 0 SHARED = 1 EXCLUSIVE = 2 @@ -35,6 +33,18 @@ def max(a: MultiLockType, b: MultiLockType) -> MultiLockType: return b return a + def excludes(self, other: MultiLockType) -> bool: + if self is MultiLockType.NONE: + return False + + return MultiLockType.EXCLUSIVE in (self, other) + + def includes(self, other: MultiLockType) -> bool: + if self is MultiLockType.NONE: + return True + + return MultiLockType.EXCLUSIVE not in (self, other) + class MultiLock: """Shared/exclusive mode lock. @@ -55,20 +65,20 @@ class MultiLock: If not given they will be created and returned as appropriate. """ - def __init__(self): - self._locked: Optional[MultiLockType] = __sentinel__ # type: ignore + def __init__(self) -> None: + self._locked: Optional[MultiLockType] = None self._acquire: Dict[Hashable, MultiLockType] = {} self._notify: Dict[Event, MultiLockType] = {} @property - def locked(self) -> Optional[MultiLockType]: + def locked(self) -> MultiLockType: """Lock state. None if not locked. Acquired lock type otherwise. """ - if self._locked is __sentinel__: + if self._locked is None: self._locked = ( - None + MultiLockType.NONE if not self._acquire else reduce(MultiLockType.max, self._acquire.values()) ) @@ -90,7 +100,7 @@ def notify( raise EventUsedError(f"{id(self)}: event {id(event)} already in notify") self._notify[event] = type try: - if not self.locked or MultiLockType.EXCLUSIVE not in (self.locked, type): + if self.locked.includes(type): event.set() yield event finally: @@ -115,7 +125,7 @@ def acquire_nowait( :param handle: User-specified lock handle. :return: Handle if lock is acquired. None otherwise. """ - if self.locked and MultiLockType.EXCLUSIVE in (self.locked, type): + if self.locked.excludes(type): return None if handle is None: handle = object() @@ -124,7 +134,7 @@ def acquire_nowait( f"{id(self)}: handle {id(handle)} already in acquired" ) self._acquire[handle] = type - self._locked = __sentinel__ # type: ignore + self._locked = None return handle async def acquire( @@ -151,9 +161,9 @@ async def acquire( def release(self, handle: Hashable) -> None: self._acquire.pop(handle, None) # OK to release unknown handle. - self._locked = __sentinel__ # type: ignore + self._locked = None for handle, type in self._notify.items(): - if not self.locked or MultiLockType.EXCLUSIVE not in (self.locked, type): + if self.locked.includes(type): handle.set() @asynccontextmanager diff --git a/tests/test_multilock.py b/tests/test_multilock.py index cdf4141..c27671a 100644 --- a/tests/test_multilock.py +++ b/tests/test_multilock.py @@ -38,7 +38,9 @@ def test_acquire_nowait_fail_when_exclusive(type: MultiLockType) -> None: assert not lock.locked -@mark.parametrize(["type"], [param(type) for type in MultiLockType]) +@mark.parametrize( + ["type"], [param(MultiLockType.SHARED), param(MultiLockType.EXCLUSIVE)] +) def test_acquire_nowait_fail_exclusive_when_locked(type: MultiLockType) -> None: lock = MultiLock() acquired = lock.acquire_nowait(type) @@ -106,7 +108,9 @@ async def test_acquire_wait_when_exclusive(type: MultiLockType) -> None: @mark.timeout(3) -@mark.parametrize(["type"], [param(type) for type in MultiLockType]) +@mark.parametrize( + ["type"], [param(MultiLockType.SHARED), param(MultiLockType.EXCLUSIVE)] +) async def test_acquire_wait_exclusive_when_locked(type: MultiLockType): lock = MultiLock() acquired = await lock.acquire(type) @@ -169,7 +173,9 @@ async def test_notify_wait_when_exclusive(type: MultiLockType): @mark.timeout(3) -@mark.parametrize(["type"], [param(type) for type in MultiLockType]) +@mark.parametrize( + ["type"], [param(MultiLockType.SHARED), param(MultiLockType.EXCLUSIVE)] +) async def test_notify_wait_exclusive_when_locked(type: MultiLockType): lock = MultiLock() handle = lock.acquire_nowait(type)