Skip to content

Commit

Permalink
feat: Conditionals with boolean algebra (#487)
Browse files Browse the repository at this point in the history
  • Loading branch information
fgmacedo authored Nov 2, 2024
1 parent ff14d62 commit 1275cd4
Show file tree
Hide file tree
Showing 7 changed files with 462 additions and 12 deletions.
20 changes: 20 additions & 0 deletions docs/guards.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,31 @@ unless
* Single condition: `unless="condition"`
* Multiple conditions: `unless=["condition1", "condition2"]`

Conditions also support [Boolean algebra](https://en.wikipedia.org/wiki/Boolean_algebra) expressions, allowing you to use compound logic within transition guards. You can use both standard Python logical operators (`not`, `and`, `or`) as well as classic Boolean algebra symbols:

- `!` for `not`
- `^` for `and`
- `v` for `or`

For example:

```python
start.to(end, cond="frodo_has_ring and gandalf_present or !sauron_alive")
```

Both formats can be used interchangeably, so `!sauron_alive` and `not sauron_alive` are equivalent.


```{seealso}
See {ref}`sphx_glr_auto_examples_air_conditioner_machine.py` for an example of
combining multiple transitions to the same event.
```

```{seealso}
See {ref}`sphx_glr_auto_examples_lor_machine.py` for an example of
using boolean algebra in conditions.
```

```{hint}
In Python, a boolean value is either `True` or `False`. However, there are also specific values that
are considered "**falsy**" and will evaluate as `False` when used in a boolean context.
Expand Down
72 changes: 62 additions & 10 deletions statemachine/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,30 @@
from enum import IntEnum
from enum import IntFlag
from enum import auto
from functools import partial
from functools import reduce
from inspect import isawaitable
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Set
from typing import Type

from .exceptions import AttrNotFound
from .exceptions import InvalidDefinition
from .i18n import _
from .spec_parser import custom_and
from .spec_parser import operator_mapping
from .spec_parser import parse_boolean_expr
from .utils import ensure_iterable

if TYPE_CHECKING:
from statemachine.dispatcher import Listeners


class CallbackPriority(IntEnum):
GENERIC = 0
Expand Down Expand Up @@ -54,6 +65,17 @@ def allways_true(*args, **kwargs):
return True


def take_callback(name: str, resolver: "Listeners", not_found_handler: Callable) -> Callable:
callbacks = list(resolver.search_name(name))
if len(callbacks) == 0:
not_found_handler(name)
return allways_true
elif len(callbacks) == 1:
return callbacks[0]
else:
return reduce(custom_and, callbacks)


class CallbackSpec:
"""Specs about callbacks.
Expand Down Expand Up @@ -110,22 +132,46 @@ def _update_func(self, func: Callable, attr_name: str):
self.reference = SpecReference.CALLABLE
self.attr_name = attr_name

def build(self, resolver) -> Generator["CallbackWrapper", None, None]:
def _wrap(self, callback):
condition = self.cond if self.cond is not None else allways_true
return CallbackWrapper(
callback=callback,
condition=condition,
meta=self,
unique_key=callback.unique_key,
)

def build(self, resolver: "Listeners") -> Generator["CallbackWrapper", None, None]:
"""
Resolves the `func` into a usable callable.
Args:
resolver (callable): A method responsible to build and return a valid callable that
can receive arbitrary parameters like `*args, **kwargs`.
"""
for callback in resolver.search(self):
condition = self.cond if self.cond is not None else allways_true
yield CallbackWrapper(
callback=callback,
condition=condition,
meta=self,
unique_key=callback.unique_key,
if (
not self.is_convention
and self.group == CallbackGroup.COND
and self.reference == SpecReference.NAME
):
names_not_found: Set[str] = set()
take_callback_partial = partial(
take_callback, resolver=resolver, not_found_handler=names_not_found.add
)
try:
expression = parse_boolean_expr(self.func, take_callback_partial, operator_mapping)
except SyntaxError as err:
raise InvalidDefinition(
_("Failed to parse boolean expression '{}'").format(self.func)
) from err
if not expression or names_not_found:
self.names_not_found = names_not_found
return
yield self._wrap(expression)
return

for callback in resolver.search(self):
yield self._wrap(callback)


class SpecListGrouper:
Expand Down Expand Up @@ -292,15 +338,15 @@ def __repr__(self):
def __str__(self):
return ", ".join(str(c) for c in self)

def _add(self, spec: CallbackSpec, resolver: Callable):
def _add(self, spec: CallbackSpec, resolver: "Listeners"):
for callback in spec.build(resolver):
if callback.unique_key in self.items_already_seen:
continue

self.items_already_seen.add(callback.unique_key)
insort(self.items, callback)

def add(self, items: Iterable[CallbackSpec], resolver: Callable):
def add(self, items: Iterable[CallbackSpec], resolver: "Listeners"):
"""Validate configurations"""
for item in items:
self._add(item, resolver)
Expand Down Expand Up @@ -356,6 +402,12 @@ def check(self, specs: CallbackSpecList):
callback for callback in self[meta.group.build_key(specs)] if callback.meta == meta
):
continue
if hasattr(meta, "names_not_found"):
raise AttrNotFound(
_("Did not found name '{}' from model or statemachine").format(
", ".join(meta.names_not_found)
),
)
raise AttrNotFound(
_("Did not found name '{}' from model or statemachine").format(meta.func)
)
Expand Down
5 changes: 3 additions & 2 deletions statemachine/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def resolve(

def search(self, spec: "CallbackSpec") -> Generator["Callable", None, None]:
if spec.reference is SpecReference.NAME:
yield from self._search_name(spec.func)
yield from self.search_name(spec.func)
return
elif spec.reference is SpecReference.CALLABLE:
yield self._search_callable(spec)
Expand Down Expand Up @@ -111,7 +111,7 @@ def _search_callable(self, spec) -> "Callable":

return callable_method(spec.attr_name, spec.func, None)

def _search_name(self, name) -> Generator["Callable", None, None]:
def search_name(self, name) -> Generator["Callable", None, None]:
for config in self.items:
if name not in config.all_attrs:
continue
Expand Down Expand Up @@ -143,6 +143,7 @@ def method(*args, **kwargs):
return getter(obj)

method.unique_key = f"{attribute}@{resolver_id}" # type: ignore[attr-defined]
method.__name__ = attribute
return method


Expand Down
79 changes: 79 additions & 0 deletions statemachine/spec_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import ast
import re
from typing import Callable

replacements = {"!": "not ", "^": " and ", "v": " or "}

pattern = re.compile(r"\!|\^|\bv\b")


def replace_operators(expr: str) -> str:
# preprocess the expression adding support for classical logical operators
def match_func(match):
return replacements[match.group(0)]

return pattern.sub(match_func, expr)


def custom_not(predicate: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return not predicate(*args, **kwargs)

decorated.__name__ = f"not({predicate.__name__})"
unique_key = getattr(predicate, "unique_key", "")
decorated.unique_key = f"not({unique_key})" # type: ignore[attr-defined]
return decorated


def _unique_key(left, right, operator) -> str:
left_key = getattr(left, "unique_key", "")
right_key = getattr(right, "unique_key", "")
return f"{left_key} {operator} {right_key}"


def custom_and(left: Callable, right: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return left(*args, **kwargs) and right(*args, **kwargs) # type: ignore[no-any-return]

decorated.__name__ = f"({left.__name__} and {right.__name__})"
decorated.unique_key = _unique_key(left, right, "and") # type: ignore[attr-defined]
return decorated


def custom_or(left: Callable, right: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return]

decorated.__name__ = f"({left.__name__} or {right.__name__})"
decorated.unique_key = _unique_key(left, right, "or") # type: ignore[attr-defined]
return decorated


def build_expression(node, variable_hook, operator_mapping):
if isinstance(node, ast.BoolOp):
# Handle `and` / `or` operations
operator_fn = operator_mapping[type(node.op)]
left_expr = build_expression(node.values[0], variable_hook, operator_mapping)
for right in node.values[1:]:
right_expr = build_expression(right, variable_hook, operator_mapping)
left_expr = operator_fn(left_expr, right_expr)
return left_expr
elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
# Handle `not` operation
operand_expr = build_expression(node.operand, variable_hook, operator_mapping)
return operator_mapping[type(node.op)](operand_expr)
elif isinstance(node, ast.Name):
# Handle variables by calling the variable_hook
return variable_hook(node.id)
else:
raise ValueError(f"Unsupported expression structure: {node.__class__.__name__}")


def parse_boolean_expr(expr, variable_hook, operator_mapping):
"""Parses the expression into an AST and build a custom expression tree"""
expr = replace_operators(expr)
tree = ast.parse(expr, mode="eval")
return build_expression(tree.body, variable_hook, operator_mapping)


operator_mapping = {ast.Or: custom_or, ast.And: custom_and, ast.Not: custom_not}
102 changes: 102 additions & 0 deletions tests/examples/lor_machine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
Lord of the Rings Quest - Boolean algebra
=========================================
Example that demonstrates the use of Boolean algebra in conditions.
"""

from statemachine import State
from statemachine import StateMachine
from statemachine.exceptions import TransitionNotAllowed


class LordOfTheRingsQuestStateMachine(StateMachine):
# Define the states
shire = State("In the Shire", initial=True)
bree = State("In Bree")
rivendell = State("At Rivendell")
moria = State("In Moria")
lothlorien = State("In Lothlorien")
mordor = State("In Mordor")
mount_doom = State("At Mount Doom", final=True)

# Define transitions with Boolean conditions
start_journey = shire.to(bree, cond="frodo_has_ring and !sauron_alive")
meet_elves = bree.to(rivendell, cond="gandalf_present and frodo_has_ring")
enter_moria = rivendell.to(moria, cond="orc_army_nearby or frodo_has_ring")
reach_lothlorien = moria.to(lothlorien, cond="!orc_army_nearby")
journey_to_mordor = lothlorien.to(mordor, cond="frodo_has_ring and sam_is_loyal")
destroy_ring = mordor.to(mount_doom, cond="frodo_has_ring and frodo_resists_ring")

# Conditions (attributes representing the state of conditions)
frodo_has_ring: bool = True
sauron_alive: bool = True # Initially, Sauron is alive
gandalf_present: bool = False # Gandalf is not present at the start
orc_army_nearby: bool = False
sam_is_loyal: bool = True
frodo_resists_ring: bool = False # Initially, Frodo is not resisting the ring


# %%
# Playing

quest = LordOfTheRingsQuestStateMachine()

# Track state changes
print(f"Current State: {quest.current_state.id}") # Should start at "shire"

# Step 1: Start the journey
quest.sauron_alive = False # Assume Sauron is no longer alive
try:
quest.start_journey()
print(f"Current State: {quest.current_state.id}") # Should be "bree"
except TransitionNotAllowed:
print("Unable to start journey: conditions not met.")

# Step 2: Meet the elves in Rivendell
quest.gandalf_present = True # Gandalf is now present
try:
quest.meet_elves()
print(f"Current State: {quest.current_state.id}") # Should be "rivendell"
except TransitionNotAllowed:
print("Unable to meet elves: conditions not met.")

# Step 3: Enter Moria
quest.orc_army_nearby = True # Orc army is nearby
try:
quest.enter_moria()
print(f"Current State: {quest.current_state.id}") # Should be "moria"
except TransitionNotAllowed:
print("Unable to enter Moria: conditions not met.")

# Step 4: Reach Lothlorien
quest.orc_army_nearby = False # Orcs are no longer nearby
try:
quest.reach_lothlorien()
print(f"Current State: {quest.current_state.id}") # Should be "lothlorien"
except TransitionNotAllowed:
print("Unable to reach Lothlorien: conditions not met.")

# Step 5: Journey to Mordor
try:
quest.journey_to_mordor()
print(f"Current State: {quest.current_state.id}") # Should be "mordor"
except TransitionNotAllowed:
print("Unable to journey to Mordor: conditions not met.")

# Step 6: Fight with Smeagol
try:
quest.destroy_ring()
print(f"Current State: {quest.current_state.id}") # Should be "mount_doom"
except TransitionNotAllowed:
print("Unable to destroy the ring: conditions not met.")


# Step 7: Destroy the ring at Mount Doom
quest.frodo_resists_ring = True # Frodo is now resisting the ring
try:
quest.destroy_ring()
print(f"Current State: {quest.current_state.id}") # Should be "mount_doom"
except TransitionNotAllowed:
print("Unable to destroy the ring: conditions not met.")
Loading

0 comments on commit 1275cd4

Please sign in to comment.