From c2ff71dfd6b0b2fe42f87ed7b39daee91023c39e Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Wed, 20 Nov 2024 14:10:03 -0300 Subject: [PATCH] refac: Engine now owns the queue; Attempt to simplify engines (~15 lines less per engine) --- statemachine/contrib/diagram.py | 2 +- statemachine/engines/async_.py | 68 +++++++++++++-------------------- statemachine/engines/base.py | 41 ++++++++++++++++++++ statemachine/engines/sync.py | 67 ++++++++++++++------------------ statemachine/statemachine.py | 39 ++++++++----------- tests/conftest.py | 4 +- tests/test_transitions.py | 4 +- 7 files changed, 117 insertions(+), 108 deletions(-) create mode 100644 statemachine/engines/base.py diff --git a/statemachine/contrib/diagram.py b/statemachine/contrib/diagram.py index dd7d1d42..ee0d14f4 100644 --- a/statemachine/contrib/diagram.py +++ b/statemachine/contrib/diagram.py @@ -70,7 +70,7 @@ def _actions_getter(self): if isinstance(self.machine, StateMachine): def getter(grouper) -> str: - return self.machine._callbacks_registry.str(grouper.key) + return self.machine._callbacks.str(grouper.key) else: def getter(grouper) -> str: diff --git a/statemachine/engines/async_.py b/statemachine/engines/async_.py index 3425e437..9d2b3f9f 100644 --- a/statemachine/engines/async_.py +++ b/statemachine/engines/async_.py @@ -1,27 +1,22 @@ -from threading import Lock from typing import TYPE_CHECKING -from weakref import proxy from ..event_data import EventData from ..event_data import TriggerData from ..exceptions import InvalidDefinition from ..exceptions import TransitionNotAllowed from ..i18n import _ -from ..state import State -from ..transition import Transition +from .base import BaseEngine if TYPE_CHECKING: from ..statemachine import StateMachine + from ..transition import Transition -class AsyncEngine: +class AsyncEngine(BaseEngine): def __init__(self, sm: "StateMachine", rtc: bool = True): - sm._engine = self - self.sm = proxy(sm) - self._sentinel = object() if not rtc: raise InvalidDefinition(_("Only RTC is supported on async engine")) - self._processing = Lock() + super().__init__(sm=sm, rtc=rtc) async def activate_initial_state(self): """ @@ -65,8 +60,8 @@ async def processing_loop(self): first_result = self._sentinel try: # Execute the triggers in the queue in FIFO order until the queue is empty - while self.sm._external_queue: - trigger_data = self.sm._external_queue.popleft() + while self._external_queue: + trigger_data = self._external_queue.popleft() try: result = await self._trigger(trigger_data) if first_result is self._sentinel: @@ -74,19 +69,17 @@ async def processing_loop(self): except Exception: # Whe clear the queue as we don't have an expected behavior # and cannot keep processing - self.sm._external_queue.clear() + self._external_queue.clear() raise finally: self._processing.release() return first_result if first_result is not self._sentinel else None async def _trigger(self, trigger_data: TriggerData): - event_data = None + executed = False if trigger_data.event == "__initial__": - transition = Transition(State(), self.sm._get_initial_state(), event="__initial__") - transition._specs.clear() - event_data = EventData(trigger_data=trigger_data, transition=transition) - await self._activate(event_data) + transition = self._initial_transition(trigger_data) + await self._activate(trigger_data, transition) return self._sentinel state = self.sm.current_state @@ -94,51 +87,44 @@ async def _trigger(self, trigger_data: TriggerData): if not transition.match(trigger_data.event): continue - event_data = EventData(trigger_data=trigger_data, transition=transition) - args, kwargs = event_data.args, event_data.extended_kwargs - await self.sm._callbacks_registry.async_call( - transition.validators.key, *args, **kwargs - ) - if not await self.sm._callbacks_registry.async_all( - transition.cond.key, *args, **kwargs - ): + executed, result = await self._activate(trigger_data, transition) + if not executed: continue - - result = await self._activate(event_data) - event_data.result = result - event_data.executed = True break else: if not self.sm.allow_event_without_transition: raise TransitionNotAllowed(trigger_data.event, state) - return event_data.result if event_data else None + return result if executed else None - async def _activate(self, event_data: EventData): + async def _activate(self, trigger_data: TriggerData, transition: "Transition"): + event_data = EventData(trigger_data=trigger_data, transition=transition) args, kwargs = event_data.args, event_data.extended_kwargs - transition = event_data.transition - source = event_data.state + + await self.sm._callbacks.async_call(transition.validators.key, *args, **kwargs) + if not await self.sm._callbacks.async_all(transition.cond.key, *args, **kwargs): + return False, None + + source = transition.source target = transition.target - result = await self.sm._callbacks_registry.async_call( - transition.before.key, *args, **kwargs - ) + result = await self.sm._callbacks.async_call(transition.before.key, *args, **kwargs) if source is not None and not transition.internal: - await self.sm._callbacks_registry.async_call(source.exit.key, *args, **kwargs) + await self.sm._callbacks.async_call(source.exit.key, *args, **kwargs) - result += await self.sm._callbacks_registry.async_call(transition.on.key, *args, **kwargs) + result += await self.sm._callbacks.async_call(transition.on.key, *args, **kwargs) self.sm.current_state = target event_data.state = target kwargs["state"] = target if not transition.internal: - await self.sm._callbacks_registry.async_call(target.enter.key, *args, **kwargs) - await self.sm._callbacks_registry.async_call(transition.after.key, *args, **kwargs) + await self.sm._callbacks.async_call(target.enter.key, *args, **kwargs) + await self.sm._callbacks.async_call(transition.after.key, *args, **kwargs) if len(result) == 0: result = None elif len(result) == 1: result = result[0] - return result + return True, result diff --git a/statemachine/engines/base.py b/statemachine/engines/base.py new file mode 100644 index 00000000..9abc1fe5 --- /dev/null +++ b/statemachine/engines/base.py @@ -0,0 +1,41 @@ +from collections import deque +from threading import Lock +from typing import TYPE_CHECKING +from weakref import proxy + +from statemachine.event import BoundEvent + +from ..event_data import TriggerData +from ..state import State +from ..transition import Transition + +if TYPE_CHECKING: + from ..statemachine import StateMachine + + +class BaseEngine: + def __init__(self, sm: "StateMachine", rtc: bool = True): + self.sm: StateMachine = proxy(sm) + self._external_queue: deque = deque() + self._sentinel = object() + self._rtc = rtc + self._processing = Lock() + + def put(self, trigger_data: TriggerData): + """Put the trigger on the queue without blocking the caller.""" + self._external_queue.append(trigger_data) + + def start(self): + if self.sm.current_state_value is not None: + return + + trigger_data = TriggerData( + machine=self.sm, + event=BoundEvent("__initial__", _sm=self.sm), + ) + self.put(trigger_data) + + def _initial_transition(self, trigger_data): + transition = Transition(State(), self.sm._get_initial_state(), event="__initial__") + transition._specs.clear() + return transition diff --git a/statemachine/engines/sync.py b/statemachine/engines/sync.py index 32e00bf2..4400cd08 100644 --- a/statemachine/engines/sync.py +++ b/statemachine/engines/sync.py @@ -1,24 +1,17 @@ -from threading import Lock from typing import TYPE_CHECKING -from weakref import proxy from ..event_data import EventData from ..event_data import TriggerData from ..exceptions import TransitionNotAllowed -from ..state import State -from ..transition import Transition +from .base import BaseEngine if TYPE_CHECKING: - from ..statemachine import StateMachine + from ..transition import Transition -class SyncEngine: - def __init__(self, sm: "StateMachine", rtc: bool = True): - sm._engine = self - self.sm = proxy(sm) - self._sentinel = object() - self._rtc = rtc - self._processing = Lock() +class SyncEngine(BaseEngine): + def start(self): + super().start() self.activate_initial_state() def activate_initial_state(self): @@ -54,7 +47,7 @@ def processing_loop(self): """ if not self._rtc: # The machine is in "synchronous" mode - trigger_data = self.sm._external_queue.popleft() + trigger_data = self._external_queue.popleft() return self._trigger(trigger_data) # We make sure that only the first event enters the processing critical section, @@ -68,8 +61,8 @@ def processing_loop(self): first_result = self._sentinel try: # Execute the triggers in the queue in FIFO order until the queue is empty - while self.sm._external_queue: - trigger_data = self.sm._external_queue.popleft() + while self._external_queue: + trigger_data = self._external_queue.popleft() try: result = self._trigger(trigger_data) if first_result is self._sentinel: @@ -77,19 +70,17 @@ def processing_loop(self): except Exception: # Whe clear the queue as we don't have an expected behavior # and cannot keep processing - self.sm._external_queue.clear() + self._external_queue.clear() raise finally: self._processing.release() return first_result if first_result is not self._sentinel else None def _trigger(self, trigger_data: TriggerData): - event_data = None + executed = False if trigger_data.event == "__initial__": - transition = Transition(State(), self.sm._get_initial_state(), event="__initial__") - transition._specs.clear() - event_data = EventData(trigger_data=trigger_data, transition=transition) - self._activate(event_data) + transition = self._initial_transition(trigger_data) + self._activate(trigger_data, transition) return self._sentinel state = self.sm.current_state @@ -97,45 +88,45 @@ def _trigger(self, trigger_data: TriggerData): if not transition.match(trigger_data.event): continue - event_data = EventData(trigger_data=trigger_data, transition=transition) - args, kwargs = event_data.args, event_data.extended_kwargs - self.sm._callbacks_registry.call(transition.validators.key, *args, **kwargs) - if not self.sm._callbacks_registry.all(transition.cond.key, *args, **kwargs): + executed, result = self._activate(trigger_data, transition) + if not executed: continue - result = self._activate(event_data) - event_data.result = result - event_data.executed = True break else: if not self.sm.allow_event_without_transition: raise TransitionNotAllowed(trigger_data.event, state) - return event_data.result if event_data else None + return result if executed else None - def _activate(self, event_data: EventData): + def _activate(self, trigger_data: TriggerData, transition: "Transition"): + event_data = EventData(trigger_data=trigger_data, transition=transition) args, kwargs = event_data.args, event_data.extended_kwargs - transition = event_data.transition - source = event_data.state + + self.sm._callbacks.call(transition.validators.key, *args, **kwargs) + if not self.sm._callbacks.all(transition.cond.key, *args, **kwargs): + return False, None + + source = transition.source target = transition.target - result = self.sm._callbacks_registry.call(transition.before.key, *args, **kwargs) + result = self.sm._callbacks.call(transition.before.key, *args, **kwargs) if source is not None and not transition.internal: - self.sm._callbacks_registry.call(source.exit.key, *args, **kwargs) + self.sm._callbacks.call(source.exit.key, *args, **kwargs) - result += self.sm._callbacks_registry.call(transition.on.key, *args, **kwargs) + result += self.sm._callbacks.call(transition.on.key, *args, **kwargs) self.sm.current_state = target event_data.state = target kwargs["state"] = target if not transition.internal: - self.sm._callbacks_registry.call(target.enter.key, *args, **kwargs) - self.sm._callbacks_registry.call(transition.after.key, *args, **kwargs) + self.sm._callbacks.call(target.enter.key, *args, **kwargs) + self.sm._callbacks.call(transition.after.key, *args, **kwargs) if len(result) == 0: result = None elif len(result) == 1: result = result[0] - return result + return True, result diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index e5f64c83..cadf0e2d 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -1,5 +1,4 @@ import warnings -from collections import deque from copy import deepcopy from inspect import isawaitable from threading import Lock @@ -82,8 +81,7 @@ def __init__( self.state_field = state_field self.start_value = start_value self.allow_event_without_transition = allow_event_without_transition - self._external_queue: deque = deque() - self._callbacks_registry = CallbacksRegistry() + self._callbacks = CallbacksRegistry() self._states_for_instance: Dict[State, State] = {} self._listeners: Dict[Any, Any] = {} @@ -97,21 +95,14 @@ def __init__( # Activate the initial state, this only works if the outer scope is sync code. # for async code, the user should manually call `await sm.activate_initial_state()` # after state machine creation. - if self.current_state_value is None: - trigger_data = TriggerData( - machine=self, - event=BoundEvent("__initial__", _sm=self), - ) - self._put_nonblocking(trigger_data) + self._engine = self._get_engine(rtc) + self._engine.start() - self._engine: AsyncEngine | SyncEngine | None = None - self._select_engine(rtc) + def _get_engine(self, rtc: bool): + if self._callbacks.has_async_callbacks: + return AsyncEngine(self, rtc=rtc) - def _select_engine(self, rtc: bool): - if self._callbacks_registry.has_async_callbacks: - AsyncEngine(self, rtc=rtc) - else: - SyncEngine(self, rtc=rtc) + return SyncEngine(self, rtc=rtc) def activate_initial_state(self): result = self._engine.activate_initial_state() @@ -151,17 +142,17 @@ def __deepcopy__(self, memo): self.__deepcopy__ = deepcopy_method cp.__deepcopy__ = deepcopy_method self._engine._processing = lock - cp._callbacks_registry.clear() + cp._callbacks.clear() cp._register_callbacks([]) cp.add_listener(*cp._listeners.keys()) return cp def _get_initial_state(self): - current_state_value = self.start_value if self.start_value else self.initial_state.value + initial_state_value = self.start_value if self.start_value else self.initial_state.value try: - return self.states_map[current_state_value] + return self.states_map[initial_state_value] except KeyError as err: - raise InvalidStateValue(current_state_value) from err + raise InvalidStateValue(initial_state_value) from err def bind_events_to(self, *targets): """Bind the state machine events to the target objects.""" @@ -179,7 +170,7 @@ def bind_events_to(self, *targets): setattr(target, event, trigger) def _add_listener(self, listeners: "Listeners", allowed_references: SpecReference = SPECS_ALL): - registry = self._callbacks_registry + registry = self._callbacks for visited in iterate_states_and_transitions(self.states): listeners.resolve( visited._specs, @@ -201,7 +192,7 @@ def _register_callbacks(self, listeners: List[object]): ) ) - check_callbacks = self._callbacks_registry.check + check_callbacks = self._callbacks.check for visited in iterate_states_and_transitions(self.states): try: check_callbacks(visited._specs) @@ -210,7 +201,7 @@ def _register_callbacks(self, listeners: List[object]): f"Error on {visited!s} when resolving callbacks: {err}" ) from err - self._callbacks_registry.async_or_sync() + self._callbacks.async_or_sync() def add_observer(self, *observers): """Add a listener.""" @@ -304,7 +295,7 @@ def allowed_events(self) -> "List[Event]": def _put_nonblocking(self, trigger_data: TriggerData): """Put the trigger on the queue without blocking the caller.""" - self._external_queue.append(trigger_data) + self._engine.put(trigger_data) def send(self, event: str, *args, **kwargs): """Send an :ref:`Event` to the state machine. diff --git a/tests/conftest.py b/tests/conftest.py index 6ad318eb..a720a808 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -135,8 +135,8 @@ class TrafficLightMachine(StateMachine): stop = yellow.to(red) go = red.to(green) - def _select_engine(self, rtc: bool): - engine(self, rtc) + def _get_engine(self, rtc: bool): + return engine(self, rtc) return TrafficLightMachine diff --git a/tests/test_transitions.py b/tests/test_transitions.py index 3f2c8150..5f975db4 100644 --- a/tests/test_transitions.py +++ b/tests/test_transitions.py @@ -252,8 +252,8 @@ class TestStateMachine(StateMachine): loop = initial.to.itself(internal=internal) - def _select_engine(self, rtc: bool): - engine(self, rtc) + def _get_engine(self, rtc: bool): + return engine(self, rtc) def on_exit_initial(self): calls.append("on_exit_initial")