Skip to content

Commit

Permalink
feat: Internal transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
fgmacedo committed Dec 10, 2024
1 parent 44994ed commit 8030a84
Show file tree
Hide file tree
Showing 20 changed files with 158 additions and 57 deletions.
6 changes: 4 additions & 2 deletions statemachine/engines/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ async def _trigger(self, trigger_data: TriggerData):
await self._activate(trigger_data, transition)
return self._sentinel

state = self.sm.current_state
# TODO: Fix async engine
state = next(iter(self.sm.configuration))

for transition in state.transitions:
if not transition.match(trigger_data.event):
continue
Expand All @@ -83,7 +85,7 @@ async def _trigger(self, trigger_data: TriggerData):
break
else:
if not self.sm.allow_event_without_transition:
raise TransitionNotAllowed(trigger_data.event, state)
raise TransitionNotAllowed(trigger_data.event, self.sm.configuration)

return result if executed else None

Expand Down
78 changes: 60 additions & 18 deletions statemachine/engines/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
from dataclasses import dataclass
from itertools import chain
from queue import PriorityQueue
from queue import Queue
from threading import Lock
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import cast
from weakref import proxy

from ..event import BoundEvent
Expand All @@ -20,6 +24,8 @@
if TYPE_CHECKING:
from ..statemachine import StateMachine

logger = logging.getLogger(__name__)


@dataclass(frozen=True, unsafe_hash=True)
class StateTransition:
Expand Down Expand Up @@ -76,7 +82,7 @@ def empty(self):
def put(self, trigger_data: TriggerData, internal: bool = False):
"""Put the trigger on the queue without blocking the caller."""
if not self.running and not self.sm.allow_event_without_transition:
raise TransitionNotAllowed(trigger_data.event, self.sm.current_state)
raise TransitionNotAllowed(trigger_data.event, self.sm.configuration)

if internal:
self.internal_queue.put(trigger_data)
Expand Down Expand Up @@ -117,7 +123,9 @@ def _conditions_match(self, transition: Transition, trigger_data: TriggerData):
self.sm._callbacks.call(transition.validators.key, *args, **kwargs)
return self.sm._callbacks.all(transition.cond.key, *args, **kwargs)

def _filter_conflicting_transitions(self, transitions: OrderedSet[Transition]):
def _filter_conflicting_transitions(
self, transitions: OrderedSet[Transition]
) -> OrderedSet[Transition]:
"""
Remove transições conflitantes, priorizando aquelas com estados de origem descendentes
ou que aparecem antes na ordem do documento.
Expand All @@ -128,18 +136,18 @@ def _filter_conflicting_transitions(self, transitions: OrderedSet[Transition]):
Returns:
OrderedSet[Transition]: Conjunto de transições sem conflitos.
"""
filtered_transitions = OrderedSet()
filtered_transitions = OrderedSet[Transition]()

# Ordena as transições na ordem dos estados que as selecionaram
for t1 in transitions:
t1_preempted = False
transitions_to_remove = OrderedSet()
transitions_to_remove = OrderedSet[Transition]()

# Verifica conflitos com as transições já filtradas
for t2 in filtered_transitions:
# Calcula os conjuntos de saída (exit sets)
t1_exit_set = self._compute_exit_set(t1)
t2_exit_set = self._compute_exit_set(t2)
t1_exit_set = self._compute_exit_set([t1])
t2_exit_set = self._compute_exit_set([t2])

Check warning on line 150 in statemachine/engines/base.py

View check run for this annotation

Codecov / codecov/patch

statemachine/engines/base.py#L149-L150

Added lines #L149 - L150 were not covered by tests

# Verifica interseção dos conjuntos de saída
if t1_exit_set & t2_exit_set: # Há interseção
Expand All @@ -162,7 +170,7 @@ def _filter_conflicting_transitions(self, transitions: OrderedSet[Transition]):
def _compute_exit_set(self, transitions: List[Transition]) -> OrderedSet[StateTransition]:
"""Compute the exit set for a transition."""

states_to_exit = OrderedSet()
states_to_exit = OrderedSet[StateTransition]()

for transition in transitions:
if transition.target is None:
Expand Down Expand Up @@ -193,6 +201,8 @@ def get_transition_domain(self, transition: Transition) -> "State | None":
and all(state.is_descendant(transition.source) for state in states)
):
return transition.source
elif transition.internal and transition.is_self and transition.target.is_atomic:
return transition.source
else:
return self.find_lcca([transition.source] + list(states))

Expand All @@ -213,6 +223,7 @@ def find_lcca(states: List[State]) -> "State | None":
ancestors = [anc for anc in head.ancestors() if anc.is_compound]

# Find the first ancestor that is also an ancestor of all other states in the list
ancestor: State
for ancestor in ancestors:
if all(state.is_descendant(ancestor) for state in tail):
return ancestor
Expand All @@ -233,13 +244,16 @@ def _select_transitions(
self, trigger_data: TriggerData, predicate: Callable
) -> OrderedSet[Transition]:
"""Select the transitions that match the trigger data."""
enabled_transitions = OrderedSet()
enabled_transitions = OrderedSet[Transition]()

# Get atomic states, TODO: sorted by document order
atomic_states = (state for state in self.sm.configuration if state.is_atomic)

def first_transition_that_matches(state: State, event: Event) -> "Transition | None":
def first_transition_that_matches(
state: State, event: "Event | None"
) -> "Transition | None":
for s in chain([state], state.ancestors()):
transition: Transition
for transition in s.transitions:
if (
not transition.initial
Expand All @@ -248,6 +262,8 @@ def first_transition_that_matches(state: State, event: Event) -> "Transition | N
):
return transition

return None

for state in atomic_states:
transition = first_transition_that_matches(state, trigger_data.event)
if transition is not None:
Expand All @@ -264,6 +280,7 @@ def microstep(self, transitions: List[Transition], trigger_data: TriggerData):
)

states_to_exit = self._exit_states(transitions, trigger_data)
logger.debug("States to exit: %s", states_to_exit)
result += self._execute_transition_content(transitions, trigger_data, lambda t: t.on.key)
self._enter_states(transitions, trigger_data, states_to_exit)
self._execute_transition_content(
Expand Down Expand Up @@ -304,7 +321,7 @@ def _exit_states(self, enabled_transitions: List[Transition], trigger_data: Trig
# self.history_values[history.id] = history_value

# Execute `onexit` handlers
if info.source is not None and not info.transition.internal:
if info.source is not None: # TODO: and not info.transition.internal:
self.sm._callbacks.call(info.source.exit.key, *args, **kwargs)

# TODO: Cancel invocations
Expand Down Expand Up @@ -343,22 +360,29 @@ def _enter_states(
"""Enter the states as determined by the given transitions."""
states_to_enter = OrderedSet[StateTransition]()
states_for_default_entry = OrderedSet[StateTransition]()
default_history_content = {}
default_history_content: Dict[str, Any] = {}

# Compute the set of states to enter
self.compute_entry_set(
enabled_transitions, states_to_enter, states_for_default_entry, default_history_content
)

# We update the configuration atomically
states_targets_to_enter = OrderedSet(info.target for info in states_to_enter)
states_targets_to_enter = OrderedSet(
info.target for info in states_to_enter if info.target
)
logger.debug("States to enter: %s", states_targets_to_enter)

configuration = self.sm.configuration
self.sm.configuration = (configuration - states_to_exit) | states_targets_to_enter
self.sm.configuration = cast(
OrderedSet[State], (configuration - states_to_exit) | states_targets_to_enter
)

# Sort states to enter in entry order
# for state in sorted(states_to_enter, key=self.entry_order): # TODO: ordegin of states_to_enter # noqa: E501
for info in states_to_enter:
target = info.target
assert target
transition = info.transition
event_data = EventData(trigger_data=trigger_data, transition=transition)
event_data.state = target
Expand All @@ -376,8 +400,8 @@ def _enter_states(
# state.is_first_entry = False

# Execute `onentry` handlers
if not transition.internal:
self.sm._callbacks.call(target.enter.key, *args, **kwargs)
# TODO: if not transition.internal:
self.sm._callbacks.call(target.enter.key, *args, **kwargs)

# Handle default initial states
# TODO: Handle default initial states
Expand All @@ -396,11 +420,17 @@ def _enter_states(
parent = target.parent
grandparent = parent.parent

Check warning on line 421 in statemachine/engines/base.py

View check run for this annotation

Codecov / codecov/patch

statemachine/engines/base.py#L420-L421

Added lines #L420 - L421 were not covered by tests

self.internal_queue.put(BoundEvent(f"done.state.{parent.id}", _sm=self.sm))
self.internal_queue.put(

Check warning on line 423 in statemachine/engines/base.py

View check run for this annotation

Codecov / codecov/patch

statemachine/engines/base.py#L423

Added line #L423 was not covered by tests
BoundEvent(f"done.state.{parent.id}", _sm=self.sm).build_trigger(
machine=self.sm
)
)
if grandparent.parallel:
if all(child.final for child in grandparent.states):
self.internal_queue.put(

Check warning on line 430 in statemachine/engines/base.py

View check run for this annotation

Codecov / codecov/patch

statemachine/engines/base.py#L430

Added line #L430 was not covered by tests
BoundEvent(f"done.state.{parent.id}", _sm=self.sm)
BoundEvent(f"done.state.{parent.id}", _sm=self.sm).build_trigger(
machine=self.sm
)
)

def compute_entry_set(
Expand Down Expand Up @@ -476,8 +506,19 @@ def add_descendant_states_to_enter(
# return

# Add the state to the entry set
states_to_enter.add(info)
if (
not self.sm.enable_self_transition_entries
and info.transition.internal
and (
info.transition.is_self
or info.transition.target.is_descendant(info.transition.source)
)
):
pass
else:
states_to_enter.add(info)
state = info.target
assert state

if state.is_compound:
# Handle compound states
Expand Down Expand Up @@ -541,6 +582,7 @@ def add_ancestor_states_to_enter(
default_history_content: A dictionary to hold temporary content for history states.
"""
state = info.target
assert state
for anc in state.ancestors(parent=ancestor):
# Add the ancestor to the entry set
info_to_add = StateTransition(

Check warning on line 588 in statemachine/engines/base.py

View check run for this annotation

Codecov / codecov/patch

statemachine/engines/base.py#L588

Added line #L588 was not covered by tests
Expand Down
9 changes: 8 additions & 1 deletion statemachine/engines/sync.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from time import sleep
from time import time
from typing import TYPE_CHECKING
Expand All @@ -12,6 +13,8 @@
if TYPE_CHECKING:
from ..transition import Transition

logger = logging.getLogger(__name__)


class SyncEngine(BaseEngine):
def start(self):
Expand Down Expand Up @@ -58,6 +61,7 @@ def processing_loop(self): # noqa: C901
# We will collect the first result as the processing result to keep backwards compatibility
# so we need to use a sentinel object instead of `None` because the first result may
# be also `None`, and on this case the `first_result` may be overridden by another result.
logger.debug("Processing loop started: %s", self.sm.current_state_value)
first_result = self._sentinel
try:
took_events = True
Expand All @@ -82,6 +86,7 @@ def processing_loop(self): # noqa: C901

enabled_transitions = self.select_transitions(internal_event)
if enabled_transitions:
logger.debug("Eventless/internal queue: %s", enabled_transitions)
took_events = True
self.microstep(list(enabled_transitions), internal_event)

Expand All @@ -102,6 +107,7 @@ def processing_loop(self): # noqa: C901
while not self.external_queue.is_empty():
took_events = True
external_event = self.external_queue.pop()
logger.debug("External event: %s", external_event)
current_time = time()
if external_event.execution_time > current_time:
self.put(external_event)
Expand All @@ -122,6 +128,7 @@ def processing_loop(self): # noqa: C901
# self.send(inv.id, external_event)

enabled_transitions = self.select_transitions(external_event)
logger.debug("Enabled transitions: %s", enabled_transitions)
if enabled_transitions:
try:
result = self.microstep(list(enabled_transitions), external_event)
Expand All @@ -136,7 +143,7 @@ def processing_loop(self): # noqa: C901

else:
if not self.sm.allow_event_without_transition:
raise TransitionNotAllowed(external_event.event, self.sm.current_state)
raise TransitionNotAllowed(external_event.event, self.sm.configuration)

finally:
self._processing.release()
Expand Down
10 changes: 6 additions & 4 deletions statemachine/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING
from typing import MutableSet

from .i18n import _

Expand Down Expand Up @@ -30,12 +31,13 @@ class AttrNotFound(InvalidDefinition):


class TransitionNotAllowed(StateMachineError):
"Raised when there's no transition that can run from the current :ref:`state`."
"Raised when there's no transition that can run from the current :ref:`configuration`."

def __init__(self, event: "Event | None", state: "State"):
def __init__(self, event: "Event | None", configuration: MutableSet["State"]):
self.event = event
self.state = state
self.configuration = configuration
name = ", ".join([s.name for s in configuration])
msg = _("Can't {} when in {}.").format(
self.event and self.event.name or "transition", self.state.name
self.event and self.event.name or "transition", name
)
super().__init__(msg)
4 changes: 2 additions & 2 deletions statemachine/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
if not cls.states:
return

cls._initials_by_document_order(cls.states, parent=None)
cls._initials_by_document_order(list(cls.states), parent=None)

initials = [s for s in cls.states if s.initial]
parallels = [s.id for s in cls.states if s.parallel]
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(

def __getattr__(self, attribute: str) -> Any: ...

def _initials_by_document_order(cls, states, parent: "State | None" = None):
def _initials_by_document_order(cls, states: List[State], parent: "State | None" = None):
"""Set initial state by document order if no explicit initial state is set"""
initial: "State | None" = None
for s in states:
Expand Down
1 change: 1 addition & 0 deletions statemachine/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def create_machine_class_from_definition(
transition = source.to(
target,
event=event_name,
internal=transition_data.get("internal"),
initial=transition_data.get("initial"),
cond=transition_data.get("cond"),
unless=transition_data.get("unless"),
Expand Down
6 changes: 5 additions & 1 deletion statemachine/io/scxml/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@ def __call__(self, *args, **kwargs):
kwargs["_ioprocessors"] = self.processor.wrap(**kwargs)

try:
return _eval(self.action, **kwargs)
result = _eval(self.action, **kwargs)
logger.debug("Cond %s -> %s", self.action, result)
return result

except Exception as e:
machine.send("error.execution", error=e, internal=True)
return False
Expand Down Expand Up @@ -238,6 +241,7 @@ def __call__(self, *args, **kwargs):
f"got: {self.action.location}"
)
setattr(obj, attr, value)
logger.debug(f"Assign: {self.action.location} = {value!r}")


class Log(CallableAction):
Expand Down
Loading

0 comments on commit 8030a84

Please sign in to comment.