Skip to content

Commit f55eb5a

Browse files
authored
fix: Fix passing conditions callbacks as references. (#431)
1 parent d195b4a commit f55eb5a

10 files changed

Lines changed: 220 additions & 64 deletions

File tree

statemachine/dispatcher.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,7 @@ def from_obj(cls, obj, skip_attrs=None):
1919
if isinstance(obj, ObjectConfig):
2020
return obj
2121
else:
22-
return cls(obj, set(skip_attrs) if skip_attrs else set(), str(id(obj)))
23-
24-
def getattr(self, attr):
25-
if attr in self.skip_attrs:
26-
return
27-
return getattr(self.obj, attr, None)
22+
return cls(obj, skip_attrs or set(), str(id(obj)))
2823

2924

3025
class WrapSearchResult:
@@ -88,29 +83,74 @@ def wrapper(*args, **kwargs):
8883
return wrapper
8984

9085

91-
def search_callable(attr, *configs) -> Generator[WrapSearchResult, None, None]:
92-
if callable(attr) or isinstance(attr, property):
93-
yield CallableSearchResult(attr, attr, None)
94-
else:
95-
for config in configs:
96-
func = config.getattr(attr)
97-
if func is None:
98-
continue
99-
if not callable(func):
100-
yield AttributeCallableSearchResult(attr, config.obj, config.resolver_id)
86+
def _search_callable_attr_is_property(
87+
attr, configs: tuple[ObjectConfig]
88+
) -> "WrapSearchResult | None":
89+
# if the attr is a property, we'll try to find the object that has the
90+
# property on the configs
91+
attr_name = attr.fget.__name__
92+
for obj, _skip_attrs, resolver_id in configs:
93+
func = getattr(type(obj), attr_name, None)
94+
if func is not None and func is attr:
95+
return AttributeCallableSearchResult(attr_name, obj, resolver_id)
96+
return None
10197

102-
if getattr(func, "_is_sm_event", False):
103-
yield EventSearchResult(attr, func, config.resolver_id)
10498

105-
yield CallableSearchResult(attr, func, config.resolver_id)
99+
def _search_callable_attr_is_callable(attr, configs: tuple[ObjectConfig]) -> WrapSearchResult:
100+
# if the attr is an unbounded method, we'll try to find the bounded method
101+
# on the configs
102+
if not hasattr(attr, "__self__"):
103+
for obj, _skip_attrs, resolver_id in configs:
104+
func = getattr(obj, attr.__name__, None)
105+
if func is not None and func.__func__ is attr:
106+
return CallableSearchResult(attr.__name__, func, resolver_id)
106107

108+
return CallableSearchResult(attr, attr, None)
109+
110+
111+
def _search_callable_in_configs(
112+
attr, configs: tuple[ObjectConfig]
113+
) -> Generator[WrapSearchResult, None, None]:
114+
for obj, skip_attrs, resolver_id in configs:
115+
if attr in skip_attrs:
116+
continue
117+
118+
if not hasattr(obj, attr):
119+
continue
120+
121+
func = getattr(obj, attr)
122+
if not callable(func):
123+
yield AttributeCallableSearchResult(attr, obj, resolver_id)
124+
125+
if getattr(func, "_is_sm_event", False):
126+
yield EventSearchResult(attr, func, resolver_id)
127+
128+
yield CallableSearchResult(attr, func, resolver_id)
107129

108-
def resolver_factory(*objects):
109-
"""Factory that returns a configured resolver."""
110130

111-
objects = [ObjectConfig.from_obj(obj) for obj in objects]
131+
def search_callable(attr, configs: tuple) -> Generator[WrapSearchResult, None, None]: # noqa: C901
132+
if isinstance(attr, property):
133+
result = _search_callable_attr_is_property(attr, configs)
134+
if result is not None:
135+
yield result
136+
return
137+
138+
if callable(attr):
139+
yield _search_callable_attr_is_callable(attr, configs)
140+
return
141+
142+
yield from _search_callable_in_configs(attr, configs)
143+
144+
145+
def resolver_factory(objects: tuple[ObjectConfig]):
146+
"""Factory that returns a configured resolver."""
112147

113148
def wrapper(attr) -> Generator[WrapSearchResult, None, None]:
114-
yield from search_callable(attr, *objects)
149+
yield from search_callable(attr, objects)
115150

116151
return wrapper
152+
153+
154+
def resolver_factory_from_objects(*objects: tuple[Any]):
155+
configs = tuple(ObjectConfig.from_obj(o) for o in objects)
156+
return resolver_factory(configs)

statemachine/signature.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
def _make_key(method):
1111
method = method.func if isinstance(method, partial) else method
12+
method = method.fget if isinstance(method, property) else method
1213
if isinstance(method, MethodType):
1314
return hash(
1415
(

statemachine/state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ def __repr__(self):
139139
f"initial={self.initial!r}, final={self.final!r})"
140140
)
141141

142+
def __str__(self):
143+
return self.name
144+
142145
def __get__(self, machine, owner):
143146
if machine is None:
144147
return self

statemachine/statemachine.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def __init__(
8181
if self._abstract:
8282
raise InvalidDefinition(_("There are no states or transitions."))
8383

84-
self._initial_transition = Transition(None, self._get_initial_state(), event="__initial__")
8584
self._setup()
8685
self._activate_initial_state()
8786

@@ -125,16 +124,17 @@ def _activate_initial_state(self):
125124
if self.current_state_value is None:
126125
# send an one-time event `__initial__` to enter the current state.
127126
# current_state = self.current_state
128-
self._initial_transition.before.clear()
129-
self._initial_transition.on.clear()
130-
self._initial_transition.after.clear()
127+
initial_transition = Transition(None, self._get_initial_state(), event="__initial__")
128+
initial_transition.before.clear()
129+
initial_transition.on.clear()
130+
initial_transition.after.clear()
131131

132132
event_data = EventData(
133133
trigger_data=TriggerData(
134134
machine=self,
135-
event=self._initial_transition.event,
135+
event=initial_transition.event,
136136
),
137-
transition=self._initial_transition,
137+
transition=initial_transition,
138138
)
139139
self._activate(event_data)
140140

@@ -157,29 +157,31 @@ def _iterate_states_and_transitions(self):
157157
yield state
158158
yield from state.transitions
159159

160+
def _build_register(self, observers):
161+
return partial(self._callbacks_registry.register, resolver=resolver_factory(observers))
162+
160163
def _setup(self):
161-
machine = ObjectConfig.from_obj(self, skip_attrs=self._get_protected_attrs())
162-
model = ObjectConfig.from_obj(self.model, skip_attrs={self.state_field})
163-
add_observer_visitor = self._build_observers_visitor(machine, model)
164+
register = self._build_register(
165+
(
166+
ObjectConfig.from_obj(self, skip_attrs=self._get_protected_attrs()),
167+
ObjectConfig.from_obj(self.model, skip_attrs={self.state_field}),
168+
)
169+
)
164170
check_callbacks = self._callbacks_registry.check
165171

166172
for visited in self._iterate_states_and_transitions():
167173
visited._setup()
168174

169175
for visited in self._iterate_states_and_transitions():
170-
add_observer_visitor(visited)
176+
visited._add_observer(register)
171177

172178
for visited in self._iterate_states_and_transitions():
173-
visited._check_callbacks(check_callbacks)
174-
175-
def _build_observers_visitor(self, *observers):
176-
resolver = resolver_factory(*observers)
177-
_register = partial(self._callbacks_registry.register, resolver=resolver)
178-
179-
def add_observer_visitor(visited):
180-
visited._add_observer(_register)
181-
182-
return add_observer_visitor
179+
try:
180+
visited._check_callbacks(check_callbacks)
181+
except Exception as err:
182+
raise InvalidDefinition(
183+
f"Error on {visited!s} when resolving callbacks: {err}"
184+
) from err
183185

184186
def add_observer(self, *observers):
185187
"""Add an observer.
@@ -192,10 +194,12 @@ def add_observer(self, *observers):
192194
:ref:`observers`.
193195
"""
194196
self._observers.update({o: None for o in observers})
197+
observers = tuple(ObjectConfig.from_obj(o) for o in observers)
195198

196-
add_observer_visitor = self._build_observers_visitor(*observers)
199+
register = self._build_register(observers)
197200
for visited in self._iterate_states_and_transitions():
198-
add_observer_visitor(visited)
201+
visited._add_observer(register)
202+
199203
return self
200204

201205
def _repr_html_(self):

statemachine/transition.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def __repr__(self):
7373
f"internal={self.internal!r})"
7474
)
7575

76+
def __str__(self):
77+
return f"transition {self.event!s} from {self.source!s} to {self.target!s}"
78+
7679
def _setup(self):
7780
before = self.before.add
7881
on = self.on.add

tests/examples/order_control_rich_model_machine.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from statemachine import State
1010
from statemachine import StateMachine
11-
from statemachine.exceptions import AttrNotFound
11+
from statemachine.exceptions import InvalidDefinition
1212

1313

1414
class Order:
@@ -59,11 +59,16 @@ class OrderControl(StateMachine):
5959

6060
try:
6161
control = OrderControl()
62-
except AttrNotFound as e:
62+
except InvalidDefinition as e:
6363
assert ( # noqa: PT017
64-
str(e) == "Did not found name 'payment_received' from model or statemachine"
64+
str(e)
65+
== (
66+
"Error on transition process_order from Processing to Shipping when resolving "
67+
"callbacks: Did not found name 'payment_received' from model or statemachine"
68+
)
6569
)
6670

71+
6772
# %%
6873
# Now initializing with a proper ``order`` instance.
6974

tests/test_callbacks.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from statemachine.callbacks import CallbackMetaList
1010
from statemachine.callbacks import CallbacksExecutor
1111
from statemachine.callbacks import CallbacksRegistry
12-
from statemachine.dispatcher import resolver_factory
12+
from statemachine.dispatcher import resolver_factory_from_objects
1313
from statemachine.exceptions import InvalidDefinition
1414

1515

@@ -23,7 +23,9 @@ def __init__(self):
2323
["life_meaning", "name", "a_method"],
2424
)
2525
self.registry = CallbacksRegistry()
26-
self.executor = self.registry.register(self.callbacks, resolver=resolver_factory(self))
26+
self.executor = self.registry.register(
27+
self.callbacks, resolver=resolver_factory_from_objects(self)
28+
)
2729

2830
@property
2931
def life_meaning(self):
@@ -49,7 +51,7 @@ def do_something(self, *args, **kwargs):
4951
obj = MyObject()
5052

5153
meta_list.add(obj.do_something)
52-
executor.add(meta_list, resolver_factory(obj))
54+
executor.add(meta_list, resolver_factory_from_objects(obj))
5355

5456
executor.call(1, 2, 3, a="x", b="y")
5557

@@ -80,7 +82,7 @@ def last_one(self, *args, **kwargs):
8082
callbacks.add("my_method").add("other_method")
8183
callbacks.add("last_one")
8284

83-
registry.register(callbacks, resolver_factory(obj))
85+
registry.register(callbacks, resolver_factory_from_objects(obj))
8486

8587
registry[callbacks].call(1, 2, 3, a="x", b="y")
8688

@@ -111,7 +113,7 @@ def test_raise_error_if_didnt_found_attr(self, suppress_errors):
111113
callbacks = CallbackMetaList()
112114
registry = CallbacksRegistry()
113115

114-
register = partial(registry.register, resolver=resolver_factory(self))
116+
register = partial(registry.register, resolver=resolver_factory_from_objects(self))
115117

116118
callbacks.add(
117119
"this_does_no_exist",
@@ -139,7 +141,7 @@ def func3():
139141
return {"key": "value"}
140142

141143
callbacks.add([func1, func2, func3])
142-
registry.register(callbacks, resolver_factory(object()))
144+
registry.register(callbacks, resolver_factory_from_objects(object()))
143145

144146
results = registry[callbacks].call(1, 2, 3, a="x", b="y")
145147

@@ -170,7 +172,7 @@ def hero_lowercase(hero):
170172
def race_uppercase(race):
171173
return race.upper()
172174

173-
x.registry.register(x.callbacks, resolver=resolver_factory(x))
175+
x.registry.register(x.callbacks, resolver=resolver_factory_from_objects(x))
174176

175177
assert x.executor.call(hero="Gandalf", race="Maia") == [
176178
42,

tests/test_dispatcher.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
from statemachine.dispatcher import ObjectConfig
4-
from statemachine.dispatcher import resolver_factory
4+
from statemachine.dispatcher import resolver_factory_from_objects
55
from statemachine.dispatcher import search_callable
66

77

@@ -46,28 +46,28 @@ def kwargs(self, request):
4646
def test_return_same_object_if_already_a_callable(self):
4747
model = Person("Frodo", "Bolseiro")
4848
expected = model.get_full_name
49-
actual = next(search_callable(expected)).wrap()
49+
actual = next(search_callable(expected, [])).wrap()
5050
assert actual.__name__ == expected.__name__
5151
assert actual.__doc__ == expected.__doc__
5252

5353
def test_retrieve_a_method_from_its_name(self, args, kwargs):
5454
model = Person("Frodo", "Bolseiro")
5555
expected = model.get_full_name
56-
method = next(search_callable("get_full_name", ObjectConfig.from_obj(model))).wrap()
56+
method = next(search_callable("get_full_name", [ObjectConfig.from_obj(model)])).wrap()
5757

5858
assert method.__name__ == expected.__name__
5959
assert method.__doc__ == expected.__doc__
6060
assert method(*args, **kwargs) == "Frodo Bolseiro"
6161

6262
def test_retrieve_a_callable_from_a_property_name(self, args, kwargs):
6363
model = Person("Frodo", "Bolseiro")
64-
method = next(search_callable("first_name", ObjectConfig.from_obj(model))).wrap()
64+
method = next(search_callable("first_name", [ObjectConfig.from_obj(model)])).wrap()
6565

6666
assert method(*args, **kwargs) == "Frodo"
6767

6868
def test_retrieve_callable_from_a_property_name_that_should_keep_reference(self, args, kwargs):
6969
model = Person("Frodo", "Bolseiro")
70-
method = next(search_callable("first_name", ObjectConfig.from_obj(model))).wrap()
70+
method = next(search_callable("first_name", [ObjectConfig.from_obj(model)])).wrap()
7171

7272
model.first_name = "Bilbo"
7373

@@ -88,7 +88,7 @@ def test_should_chain_resolutions(self, attr, expected_value):
8888
person = Person("Frodo", "Bolseiro", "cpf")
8989
org = Organization("The Lord fo the Rings", "cnpj")
9090

91-
resolver = resolver_factory(org, person)
91+
resolver = resolver_factory_from_objects(org, person)
9292
resolved_method = next(resolver(attr)).wrap()
9393
assert resolved_method() == expected_value
9494

@@ -107,6 +107,6 @@ def test_should_ignore_list_of_attrs(self, attr, expected_value):
107107

108108
org_config = ObjectConfig.from_obj(org, {"get_full_name"})
109109

110-
resolver = resolver_factory(org_config, person)
110+
resolver = resolver_factory_from_objects(org_config, person)
111111
resolved_method = next(resolver(attr)).wrap()
112112
assert resolved_method() == expected_value

tests/test_transition_list.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from statemachine import State
44
from statemachine.callbacks import CallbacksRegistry
5-
from statemachine.dispatcher import resolver_factory
5+
from statemachine.dispatcher import resolver_factory_from_objects
66

77

88
def test_transition_list_or_operator():
@@ -58,6 +58,6 @@ def my_callback():
5858
transition = s1.transitions[0]
5959
callback_list = getattr(transition, list_attr_name)
6060

61-
registry.register(callback_list, resolver_factory(object()))
61+
registry.register(callback_list, resolver_factory_from_objects(object()))
6262

6363
assert registry[callback_list].call() == [expected_value]

0 commit comments

Comments
 (0)