Skip to content

Commit df06976

Browse files
authored
chore: StateMachine setup as observer sharing same code (#432)
1 parent f55eb5a commit df06976

1 file changed

Lines changed: 13 additions & 18 deletions

File tree

statemachine/statemachine.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -157,24 +157,18 @@ def _iterate_states_and_transitions(self):
157157
yield state
158158
yield from state.transitions
159159

160-
def _build_register(self, observers):
161-
return partial(self._callbacks_registry.register, resolver=resolver_factory(observers))
162-
163160
def _setup(self):
164-
register = self._build_register(
161+
for visited in self._iterate_states_and_transitions():
162+
visited._setup()
163+
164+
self._add_observer(
165165
(
166166
ObjectConfig.from_obj(self, skip_attrs=self._get_protected_attrs()),
167167
ObjectConfig.from_obj(self.model, skip_attrs={self.state_field}),
168168
)
169169
)
170-
check_callbacks = self._callbacks_registry.check
171-
172-
for visited in self._iterate_states_and_transitions():
173-
visited._setup()
174-
175-
for visited in self._iterate_states_and_transitions():
176-
visited._add_observer(register)
177170

171+
check_callbacks = self._callbacks_registry.check
178172
for visited in self._iterate_states_and_transitions():
179173
try:
180174
visited._check_callbacks(check_callbacks)
@@ -183,6 +177,13 @@ def _setup(self):
183177
f"Error on {visited!s} when resolving callbacks: {err}"
184178
) from err
185179

180+
def _add_observer(self, observers):
181+
register = partial(self._callbacks_registry.register, resolver=resolver_factory(observers))
182+
for visited in self._iterate_states_and_transitions():
183+
visited._add_observer(register)
184+
185+
return self
186+
186187
def add_observer(self, *observers):
187188
"""Add an observer.
188189
@@ -194,13 +195,7 @@ def add_observer(self, *observers):
194195
:ref:`observers`.
195196
"""
196197
self._observers.update({o: None for o in observers})
197-
observers = tuple(ObjectConfig.from_obj(o) for o in observers)
198-
199-
register = self._build_register(observers)
200-
for visited in self._iterate_states_and_transitions():
201-
visited._add_observer(register)
202-
203-
return self
198+
return self._add_observer(tuple(ObjectConfig.from_obj(o) for o in observers))
204199

205200
def _repr_html_(self):
206201
return f'<div class="statemachine">{self._repr_svg_()}</div>'

0 commit comments

Comments
 (0)