Skip to content

Commit

Permalink
Implement MultiLockSet
Browse files Browse the repository at this point in the history
Fix MultiLockType.max bug where actual operation was min (oops...).
  • Loading branch information
phyrwork committed Sep 3, 2024
1 parent 7f510e0 commit 77d123c
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 170 deletions.
159 changes: 76 additions & 83 deletions src/asyncio_multilock/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from asyncio import Event
from contextlib import asynccontextmanager, contextmanager
from contextlib import asynccontextmanager, contextmanager, ExitStack
from enum import IntEnum
from functools import reduce
from typing import (
Expand All @@ -10,6 +10,7 @@
Hashable,
Iterator,
Optional,
FrozenSet,
)


Expand All @@ -29,9 +30,7 @@ class MultiLockType(IntEnum):

@staticmethod
def max(a: MultiLockType, b: MultiLockType) -> MultiLockType:
if a > b:
return b
return a
return a if a > b else b

def excludes(self, other: MultiLockType) -> bool:
if self is MultiLockType.NONE:
Expand Down Expand Up @@ -77,10 +76,8 @@ def locked(self) -> MultiLockType:
None if not locked. Acquired lock type otherwise.
"""
if self._locked is None:
self._locked = (
MultiLockType.NONE
if not self._acquire
else reduce(MultiLockType.max, self._acquire.values())
self._locked = reduce(
MultiLockType.max, self._acquire.values(), MultiLockType.NONE
)
return self._locked

Expand All @@ -94,7 +91,7 @@ def notify(
:param event: User-specified lock type acquirable event.
:return: Lock type acquirable event.
"""
if not event:
if event is None:
event = Event()
elif event in self._notify:
raise EventUsedError(f"{id(self)}: event {id(event)} already in notify")
Expand Down Expand Up @@ -187,77 +184,73 @@ async def context(
self.release(handle)


# class MultiLockSet(FrozenSet[MultiLock]):
# @property
# def locked(self) -> Optional[MultiLockType]:
# if not self:
# return None
#
# def max(
# a: MultiLockType | None, b: MultiLockType | None
# ) -> MultiLockType | None:
# if a is None:
# return b
# if b is None:
# return a
# return MultiLockType.max(a, b)
#
# return reduce(max, (lock.locked for lock in self))
#
# @asynccontextmanager
# async def notify(
# self, type: MultiLockType, event: Optional[Event] = None
# ) -> AsyncIterator[Event]:
# if event is None:
# event = Event()
# _events = Event()
#
# async def listen() -> None:
# assert event
# _events.clear()
# locked = self.locked
# if not locked or MultiLockType.EXCLUSIVE not in (locked, type):
# event.set()
#
# listener = create_task(listen())
#
# try:
# with ExitStack() as stack:
# for lock in self:
# stack.enter_context(lock.notify(type, _events))
# yield event
# finally:
# listener.cancel()
# with suppress(CancelledError):
# await listener
#
# def acquire_nowait(
# self, type: MultiLockType, handle: Optional[Hashable] = None
# ) -> Optional[Hashable]:
# locked = self.locked
# if locked and MultiLockType.EXCLUSIVE in (locked, type):
# return None
# if handle is None:
# handle = object()
# for lock in self:
# assert lock.acquire_nowait(type, handle)
# return handle
#
# async def acquire(
# self,
# \type: MultiLockType,
# handle: Optional[Hashable] = None,
# event: Optional[Event] = None,
# ) -> Hashable:
# async with self.notify(type, event) as event:
# assert event
# if not handle:
# handle = object()
# while not self.acquire_nowait(type, handle):
# await event.wait()
# event.clear()
# return handle
#
# def release(self, handle: Hashable) -> None:
# for lock in self:
# lock.release(handle)
class MultiLockSet(FrozenSet[MultiLock]):
@property
def locked(self) -> MultiLockType:
return reduce(
MultiLockType.max, (lock.locked for lock in self), MultiLockType.NONE
)

@contextmanager
def notify(
self, type: MultiLockType, event: Optional[Event] = None
) -> Iterator[Event]:
if event is None:
event = Event()

with ExitStack() as stack:
for lock in self:
stack.enter_context(lock.notify(type, event))

yield event

async def wait(self, type: MultiLockType, event: Optional[Event]) -> None:
with self.notify(type, event) as event:
while await event.wait():
if self.locked.includes(type):
return

def acquire_nowait(
self, type: MultiLockType, handle: Optional[Hashable] = None
) -> Optional[Hashable]:
if self.locked.excludes(type):
return None

if handle is None:
handle = object()

with ExitStack() as stack:
for lock in self:
lock.acquire_nowait(type, handle)

def release() -> None:
lock.release(handle)

stack.callback(release)

stack.pop_all()

return handle

async def acquire(
self,
type: MultiLockType,
handle: Optional[Hashable] = None,
event: Optional[Event] = None,
) -> Hashable:
if not handle:
handle = object()

if event is None:
event = Event()

with self.notify(type, event) as event:
while not self.acquire_nowait(type, handle):
await event.wait()
event.clear()

return handle

def release(self, handle: Hashable) -> None:
for lock in self:
lock.release(handle)
Loading

0 comments on commit 77d123c

Please sign in to comment.