Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refac: Engine now owns the queue; Attempt to simplify engines (~15 lines less per engine) #498

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion statemachine/contrib/diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _actions_getter(self):
if isinstance(self.machine, StateMachine):

def getter(grouper) -> str:
return self.machine._callbacks_registry.str(grouper.key)
return self.machine._callbacks.str(grouper.key)
else:

def getter(grouper) -> str:
Expand Down
68 changes: 27 additions & 41 deletions statemachine/engines/async_.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
from threading import Lock
from typing import TYPE_CHECKING
from weakref import proxy

from ..event_data import EventData
from ..event_data import TriggerData
from ..exceptions import InvalidDefinition
from ..exceptions import TransitionNotAllowed
from ..i18n import _
from ..state import State
from ..transition import Transition
from .base import BaseEngine

if TYPE_CHECKING:
from ..statemachine import StateMachine
from ..transition import Transition


class AsyncEngine:
class AsyncEngine(BaseEngine):
def __init__(self, sm: "StateMachine", rtc: bool = True):
sm._engine = self
self.sm = proxy(sm)
self._sentinel = object()
if not rtc:
raise InvalidDefinition(_("Only RTC is supported on async engine"))
self._processing = Lock()
super().__init__(sm=sm, rtc=rtc)

async def activate_initial_state(self):
"""
Expand Down Expand Up @@ -65,80 +60,71 @@ async def processing_loop(self):
first_result = self._sentinel
try:
# Execute the triggers in the queue in FIFO order until the queue is empty
while self.sm._external_queue:
trigger_data = self.sm._external_queue.popleft()
while self._external_queue:
trigger_data = self._external_queue.popleft()
try:
result = await self._trigger(trigger_data)
if first_result is self._sentinel:
first_result = result
except Exception:
# Whe clear the queue as we don't have an expected behavior
# and cannot keep processing
self.sm._external_queue.clear()
self._external_queue.clear()
raise
finally:
self._processing.release()
return first_result if first_result is not self._sentinel else None

async def _trigger(self, trigger_data: TriggerData):
event_data = None
executed = False
if trigger_data.event == "__initial__":
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
transition._specs.clear()
event_data = EventData(trigger_data=trigger_data, transition=transition)
await self._activate(event_data)
transition = self._initial_transition(trigger_data)
await self._activate(trigger_data, transition)
return self._sentinel

state = self.sm.current_state
for transition in state.transitions:
if not transition.match(trigger_data.event):
continue

event_data = EventData(trigger_data=trigger_data, transition=transition)
args, kwargs = event_data.args, event_data.extended_kwargs
await self.sm._callbacks_registry.async_call(
transition.validators.key, *args, **kwargs
)
if not await self.sm._callbacks_registry.async_all(
transition.cond.key, *args, **kwargs
):
executed, result = await self._activate(trigger_data, transition)
if not executed:
continue

result = await self._activate(event_data)
event_data.result = result
event_data.executed = True
break
else:
if not self.sm.allow_event_without_transition:
raise TransitionNotAllowed(trigger_data.event, state)

return event_data.result if event_data else None
return result if executed else None

async def _activate(self, event_data: EventData):
async def _activate(self, trigger_data: TriggerData, transition: "Transition"):
event_data = EventData(trigger_data=trigger_data, transition=transition)
args, kwargs = event_data.args, event_data.extended_kwargs
transition = event_data.transition
source = event_data.state

await self.sm._callbacks.async_call(transition.validators.key, *args, **kwargs)
if not await self.sm._callbacks.async_all(transition.cond.key, *args, **kwargs):
return False, None

source = transition.source
target = transition.target

result = await self.sm._callbacks_registry.async_call(
transition.before.key, *args, **kwargs
)
result = await self.sm._callbacks.async_call(transition.before.key, *args, **kwargs)
if source is not None and not transition.internal:
await self.sm._callbacks_registry.async_call(source.exit.key, *args, **kwargs)
await self.sm._callbacks.async_call(source.exit.key, *args, **kwargs)

result += await self.sm._callbacks_registry.async_call(transition.on.key, *args, **kwargs)
result += await self.sm._callbacks.async_call(transition.on.key, *args, **kwargs)

self.sm.current_state = target
event_data.state = target
kwargs["state"] = target

if not transition.internal:
await self.sm._callbacks_registry.async_call(target.enter.key, *args, **kwargs)
await self.sm._callbacks_registry.async_call(transition.after.key, *args, **kwargs)
await self.sm._callbacks.async_call(target.enter.key, *args, **kwargs)
await self.sm._callbacks.async_call(transition.after.key, *args, **kwargs)

if len(result) == 0:
result = None
elif len(result) == 1:
result = result[0]

return result
return True, result
41 changes: 41 additions & 0 deletions statemachine/engines/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from collections import deque
from threading import Lock
from typing import TYPE_CHECKING
from weakref import proxy

from statemachine.event import BoundEvent

from ..event_data import TriggerData
from ..state import State
from ..transition import Transition

if TYPE_CHECKING:
from ..statemachine import StateMachine


class BaseEngine:
def __init__(self, sm: "StateMachine", rtc: bool = True):
self.sm: StateMachine = proxy(sm)
self._external_queue: deque = deque()
self._sentinel = object()
self._rtc = rtc
self._processing = Lock()

def put(self, trigger_data: TriggerData):
"""Put the trigger on the queue without blocking the caller."""
self._external_queue.append(trigger_data)

def start(self):
if self.sm.current_state_value is not None:
return

trigger_data = TriggerData(
machine=self.sm,
event=BoundEvent("__initial__", _sm=self.sm),
)
self.put(trigger_data)

def _initial_transition(self, trigger_data):
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
transition._specs.clear()
return transition
67 changes: 29 additions & 38 deletions statemachine/engines/sync.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
from threading import Lock
from typing import TYPE_CHECKING
from weakref import proxy

from ..event_data import EventData
from ..event_data import TriggerData
from ..exceptions import TransitionNotAllowed
from ..state import State
from ..transition import Transition
from .base import BaseEngine

if TYPE_CHECKING:
from ..statemachine import StateMachine
from ..transition import Transition


class SyncEngine:
def __init__(self, sm: "StateMachine", rtc: bool = True):
sm._engine = self
self.sm = proxy(sm)
self._sentinel = object()
self._rtc = rtc
self._processing = Lock()
class SyncEngine(BaseEngine):
def start(self):
super().start()
self.activate_initial_state()

def activate_initial_state(self):
Expand Down Expand Up @@ -54,7 +47,7 @@ def processing_loop(self):
"""
if not self._rtc:
# The machine is in "synchronous" mode
trigger_data = self.sm._external_queue.popleft()
trigger_data = self._external_queue.popleft()
return self._trigger(trigger_data)

# We make sure that only the first event enters the processing critical section,
Expand All @@ -68,74 +61,72 @@ def processing_loop(self):
first_result = self._sentinel
try:
# Execute the triggers in the queue in FIFO order until the queue is empty
while self.sm._external_queue:
trigger_data = self.sm._external_queue.popleft()
while self._external_queue:
trigger_data = self._external_queue.popleft()
try:
result = self._trigger(trigger_data)
if first_result is self._sentinel:
first_result = result
except Exception:
# Whe clear the queue as we don't have an expected behavior
# and cannot keep processing
self.sm._external_queue.clear()
self._external_queue.clear()
raise
finally:
self._processing.release()
return first_result if first_result is not self._sentinel else None

def _trigger(self, trigger_data: TriggerData):
event_data = None
executed = False
if trigger_data.event == "__initial__":
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
transition._specs.clear()
event_data = EventData(trigger_data=trigger_data, transition=transition)
self._activate(event_data)
transition = self._initial_transition(trigger_data)
self._activate(trigger_data, transition)
return self._sentinel

state = self.sm.current_state
for transition in state.transitions:
if not transition.match(trigger_data.event):
continue

event_data = EventData(trigger_data=trigger_data, transition=transition)
args, kwargs = event_data.args, event_data.extended_kwargs
self.sm._callbacks_registry.call(transition.validators.key, *args, **kwargs)
if not self.sm._callbacks_registry.all(transition.cond.key, *args, **kwargs):
executed, result = self._activate(trigger_data, transition)
if not executed:
continue

result = self._activate(event_data)
event_data.result = result
event_data.executed = True
break
else:
if not self.sm.allow_event_without_transition:
raise TransitionNotAllowed(trigger_data.event, state)

return event_data.result if event_data else None
return result if executed else None

def _activate(self, event_data: EventData):
def _activate(self, trigger_data: TriggerData, transition: "Transition"):
event_data = EventData(trigger_data=trigger_data, transition=transition)
args, kwargs = event_data.args, event_data.extended_kwargs
transition = event_data.transition
source = event_data.state

self.sm._callbacks.call(transition.validators.key, *args, **kwargs)
if not self.sm._callbacks.all(transition.cond.key, *args, **kwargs):
return False, None

source = transition.source
target = transition.target

result = self.sm._callbacks_registry.call(transition.before.key, *args, **kwargs)
result = self.sm._callbacks.call(transition.before.key, *args, **kwargs)
if source is not None and not transition.internal:
self.sm._callbacks_registry.call(source.exit.key, *args, **kwargs)
self.sm._callbacks.call(source.exit.key, *args, **kwargs)

result += self.sm._callbacks_registry.call(transition.on.key, *args, **kwargs)
result += self.sm._callbacks.call(transition.on.key, *args, **kwargs)

self.sm.current_state = target
event_data.state = target
kwargs["state"] = target

if not transition.internal:
self.sm._callbacks_registry.call(target.enter.key, *args, **kwargs)
self.sm._callbacks_registry.call(transition.after.key, *args, **kwargs)
self.sm._callbacks.call(target.enter.key, *args, **kwargs)
self.sm._callbacks.call(transition.after.key, *args, **kwargs)

if len(result) == 0:
result = None
elif len(result) == 1:
result = result[0]

return result
return True, result
Loading