Skip to content

Commit

Permalink
refac: Improved isolation of components; caching results of built-in …
Browse files Browse the repository at this point in the history
…iscoroutinefunction (#493)
  • Loading branch information
fgmacedo authored Nov 12, 2024
1 parent 5528c3e commit 4e29771
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 63 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13.0"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v3
Expand All @@ -41,7 +41,7 @@ jobs:
# run ruff
#----------------------------------------------
- name: Linter with ruff
if: matrix.python-version == 3.12
if: matrix.python-version == 3.13
run: |
uv run ruff check .
uv run ruff format --check .
Expand All @@ -57,7 +57,7 @@ jobs:
#----------------------------------------------
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
if: matrix.python-version == 3.12
if: matrix.python-version == 3.13
with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
directory: .
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.12"]
python-version: ["3.13"]

# Specifying a GitHub environment is optional, but strongly encouraged
environment: release
Expand Down
24 changes: 22 additions & 2 deletions statemachine/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from enum import IntFlag
from enum import auto
from inspect import isawaitable
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING
from typing import Callable
from typing import Dict
Expand Down Expand Up @@ -233,7 +232,7 @@ def __init__(
unique_key: str,
) -> None:
self._callback = callback
self._iscoro = iscoroutinefunction(callback)
self._iscoro = getattr(callback, "is_coroutine", False)
self.condition = condition
self.meta = meta
self.unique_key = unique_key
Expand Down Expand Up @@ -361,3 +360,24 @@ def async_or_sync(self):
self.has_async_callbacks = any(
callback._iscoro for executor in self._registry.values() for callback in executor
)

def call(self, key: str, *args, **kwargs):
if key not in self._registry:
return []
return self._registry[key].call(*args, **kwargs)

def async_call(self, key: str, *args, **kwargs):
return self._registry[key].async_call(*args, **kwargs)

def all(self, key: str, *args, **kwargs):
if key not in self._registry:
return True
return self._registry[key].all(*args, **kwargs)

def async_all(self, key: str, *args, **kwargs):
return self._registry[key].async_all(*args, **kwargs)

def str(self, key: str) -> str:
if key not in self._registry:
return ""
return str(self._registry[key])
8 changes: 4 additions & 4 deletions statemachine/contrib/diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class DotGraphMachine:
transition_font_size = "9"
"""Transition font size in points"""

def __init__(self, machine):
def __init__(self, machine: StateMachine):
self.machine = machine

def _get_graph(self):
Expand Down Expand Up @@ -69,11 +69,11 @@ def _initial_edge(self):
def _actions_getter(self):
if isinstance(self.machine, StateMachine):

def getter(grouper):
return self.machine._get_callbacks(grouper.key)
def getter(grouper) -> str:
return self.machine._callbacks_registry.str(grouper.key)
else:

def getter(grouper):
def getter(grouper) -> str:
all_names = set(dir(self.machine))
return ", ".join(
str(c) for c in grouper if not c.is_convention or c.func in all_names
Expand Down
25 changes: 21 additions & 4 deletions statemachine/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,27 @@ def search_name(self, name):


def callable_method(a_callable) -> Callable:
method = SignatureAdapter.wrap(a_callable)
method.__name__ = a_callable.__name__
method.__doc__ = a_callable.__doc__
return method
sig = SignatureAdapter.from_callable(a_callable)
sig_bind_expected = sig.bind_expected

metadata_to_copy = a_callable.func if isinstance(a_callable, partial) else a_callable

if sig.is_coroutine:

async def signature_adapter(*args: Any, **kwargs: Any) -> Any:
ba = sig_bind_expected(*args, **kwargs)
return await a_callable(*ba.args, **ba.kwargs)
else:

def signature_adapter(*args: Any, **kwargs: Any) -> Any: # type: ignore[misc]
ba = sig_bind_expected(*args, **kwargs)
return a_callable(*ba.args, **ba.kwargs)

signature_adapter.__name__ = metadata_to_copy.__name__
signature_adapter.__doc__ = metadata_to_copy.__doc__
signature_adapter.is_coroutine = sig.is_coroutine # type: ignore[attr-defined]

return signature_adapter


def attr_method(attribute, obj) -> Callable:
Expand Down
20 changes: 13 additions & 7 deletions statemachine/engines/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,12 @@ async def _trigger(self, trigger_data: TriggerData):

event_data = EventData(trigger_data=trigger_data, transition=transition)
args, kwargs = event_data.args, event_data.extended_kwargs
await self.sm._get_callbacks(transition.validators.key).async_call(*args, **kwargs)
if not await self.sm._get_callbacks(transition.cond.key).async_all(*args, **kwargs):
await self.sm._callbacks_registry.async_call(
transition.validators.key, *args, **kwargs
)
if not await self.sm._callbacks_registry.async_all(
transition.cond.key, *args, **kwargs
):
continue

result = await self._activate(event_data)
Expand All @@ -115,19 +119,21 @@ async def _activate(self, event_data: EventData):
source = event_data.state
target = transition.target

result = await self.sm._get_callbacks(transition.before.key).async_call(*args, **kwargs)
result = await self.sm._callbacks_registry.async_call(
transition.before.key, *args, **kwargs
)
if source is not None and not transition.internal:
await self.sm._get_callbacks(source.exit.key).async_call(*args, **kwargs)
await self.sm._callbacks_registry.async_call(source.exit.key, *args, **kwargs)

result += await self.sm._get_callbacks(transition.on.key).async_call(*args, **kwargs)
result += await self.sm._callbacks_registry.async_call(transition.on.key, *args, **kwargs)

self.sm.current_state = target
event_data.state = target
kwargs["state"] = target

if not transition.internal:
await self.sm._get_callbacks(target.enter.key).async_call(*args, **kwargs)
await self.sm._get_callbacks(transition.after.key).async_call(*args, **kwargs)
await self.sm._callbacks_registry.async_call(target.enter.key, *args, **kwargs)
await self.sm._callbacks_registry.async_call(transition.after.key, *args, **kwargs)

if len(result) == 0:
result = None
Expand Down
14 changes: 7 additions & 7 deletions statemachine/engines/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def _trigger(self, trigger_data: TriggerData):

event_data = EventData(trigger_data=trigger_data, transition=transition)
args, kwargs = event_data.args, event_data.extended_kwargs
self.sm._get_callbacks(transition.validators.key).call(*args, **kwargs)
if not self.sm._get_callbacks(transition.cond.key).all(*args, **kwargs):
self.sm._callbacks_registry.call(transition.validators.key, *args, **kwargs)
if not self.sm._callbacks_registry.all(transition.cond.key, *args, **kwargs):
continue

result = self._activate(event_data)
Expand All @@ -118,19 +118,19 @@ def _activate(self, event_data: EventData):
source = event_data.state
target = transition.target

result = self.sm._get_callbacks(transition.before.key).call(*args, **kwargs)
result = self.sm._callbacks_registry.call(transition.before.key, *args, **kwargs)
if source is not None and not transition.internal:
self.sm._get_callbacks(source.exit.key).call(*args, **kwargs)
self.sm._callbacks_registry.call(source.exit.key, *args, **kwargs)

result += self.sm._get_callbacks(transition.on.key).call(*args, **kwargs)
result += self.sm._callbacks_registry.call(transition.on.key, *args, **kwargs)

self.sm.current_state = target
event_data.state = target
kwargs["state"] = target

if not transition.internal:
self.sm._get_callbacks(target.enter.key).call(*args, **kwargs)
self.sm._get_callbacks(transition.after.key).call(*args, **kwargs)
self.sm._callbacks_registry.call(target.enter.key, *args, **kwargs)
self.sm._callbacks_registry.call(transition.after.key, *args, **kwargs)

if len(result) == 0:
result = None
Expand Down
33 changes: 7 additions & 26 deletions statemachine/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from itertools import chain
from types import MethodType
from typing import Any
from typing import Callable


def _make_key(method):
Expand Down Expand Up @@ -44,40 +43,22 @@ def cached_function(cls, method):


class SignatureAdapter(Signature):
@classmethod
def wrap(cls, method) -> Callable:
"""Build a wrapper that adapts the received arguments to the inner ``method`` signature"""

sig = cls.from_callable(method)
sig_bind_expected = sig.bind_expected

metadata_to_copy = method.func if isinstance(method, partial) else method

if iscoroutinefunction(method):

async def signature_adapter(*args: Any, **kwargs: Any) -> Any:
ba = sig_bind_expected(*args, **kwargs)
return await method(*ba.args, **ba.kwargs)
else:

def signature_adapter(*args: Any, **kwargs: Any) -> Any: # type: ignore[misc]
ba = sig_bind_expected(*args, **kwargs)
return method(*ba.args, **ba.kwargs)

signature_adapter.__name__ = metadata_to_copy.__name__

return signature_adapter
is_coroutine: bool = False

@classmethod
@signature_cache
def from_callable(cls, method):
if hasattr(method, "__signature__"):
sig = method.__signature__
return SignatureAdapter(
adapter = SignatureAdapter(
sig.parameters.values(),
return_annotation=sig.return_annotation,
)
return super().from_callable(method)
else:
adapter = super().from_callable(method)

adapter.is_coroutine = iscoroutinefunction(method)
return adapter

def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: # noqa: C901
"""Get a BoundArguments object, that maps the passed `args`
Expand Down
4 changes: 4 additions & 0 deletions statemachine/spec_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def build_expression(node, variable_hook, operator_mapping):

def parse_boolean_expr(expr, variable_hook, operator_mapping):
"""Parses the expression into an AST and build a custom expression tree"""
if expr.strip() == "":
raise SyntaxError("Empty expression")
if "!" not in expr and " " not in expr:
return variable_hook(expr)
expr = replace_operators(expr)
tree = ast.parse(expr, mode="eval")
return build_expression(tree.body, variable_hook, operator_mapping)
Expand Down
4 changes: 0 additions & 4 deletions statemachine/statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from .callbacks import SPECS_ALL
from .callbacks import SPECS_SAFE
from .callbacks import CallbacksExecutor
from .callbacks import CallbacksRegistry
from .callbacks import SpecReference
from .dispatcher import Listener
Expand Down Expand Up @@ -322,6 +321,3 @@ def send(self, event: str, *args, **kwargs):
if not isawaitable(result):
return result
return run_async_from_sync(result)

def _get_callbacks(self, key) -> CallbacksExecutor:
return self._callbacks_registry[key]
6 changes: 3 additions & 3 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from statemachine.signature import SignatureAdapter
from statemachine.dispatcher import callable_method


def single_positional_param(a):
Expand Down Expand Up @@ -147,7 +147,7 @@ class TestSignatureAdapter:
],
)
def test_wrap_fn_single_positional_parameter(self, func, args, kwargs, expected):
wrapped_func = SignatureAdapter.wrap(func)
wrapped_func = callable_method(func)
assert wrapped_func.__name__ == func.__name__

if inspect.isclass(expected) and issubclass(expected, Exception):
Expand All @@ -158,7 +158,7 @@ def test_wrap_fn_single_positional_parameter(self, func, args, kwargs, expected)

def test_support_for_partial(self):
part = partial(positional_and_kw_arguments, event="activated")
wrapped_func = SignatureAdapter.wrap(part)
wrapped_func = callable_method(part)

assert wrapped_func("A", "B") == ("A", "B", "activated")
assert wrapped_func.__name__ == positional_and_kw_arguments.__name__
4 changes: 2 additions & 2 deletions tests/test_signature_positional_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from statemachine.signature import SignatureAdapter
from statemachine.dispatcher import callable_method


class TestSignatureAdapter:
Expand All @@ -25,7 +25,7 @@ def func(pos_only, /, pos_or_kw_param, *, kw_only_param):
# https://peps.python.org/pep-0570/
return pos_only, pos_or_kw_param, kw_only_param

wrapped_func = SignatureAdapter.wrap(func)
wrapped_func = callable_method(func)

if inspect.isclass(expected) and issubclass(expected, Exception):
with pytest.raises(expected):
Expand Down

0 comments on commit 4e29771

Please sign in to comment.