Skip to content

Commit 16f2a73

Browse files
authored
fix: Reconstructing callable references when using deepcopy over a SM. (#425)
1 parent f236cad commit 16f2a73

3 files changed

Lines changed: 52 additions & 16 deletions

File tree

statemachine/callbacks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ def register(self, callbacks: CallbackMetaList, resolver):
240240
executor_list.add(callbacks, resolver)
241241
return executor_list
242242

243+
def clear(self):
244+
self._registry.clear()
245+
243246
def __getitem__(self, callbacks: CallbackMetaList) -> CallbacksExecutor:
244247
return self._registry[callbacks]
245248

statemachine/statemachine.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import deque
2+
from copy import deepcopy
23
from functools import partial
34
from typing import TYPE_CHECKING
45
from typing import Any
@@ -78,9 +79,9 @@ def __init__(
7879
if self._abstract:
7980
raise InvalidDefinition(_("There are no states or transitions."))
8081

81-
initial_transition = Transition(None, self._get_initial_state(), event="__initial__")
82-
self._setup(initial_transition)
83-
self._activate_initial_state(initial_transition)
82+
self._initial_transition = Transition(None, self._get_initial_state(), event="__initial__")
83+
self._setup()
84+
self._activate_initial_state()
8485

8586
def __init_subclass__(cls, strict_states: bool = False):
8687
cls._strict_states = strict_states
@@ -98,27 +99,39 @@ def __repr__(self):
9899
f"current_state={current_state_id!r})"
99100
)
100101

102+
def __deepcopy__(self, memo):
103+
deepcopy_method = self.__deepcopy__
104+
self.__deepcopy__ = None
105+
try:
106+
cp = deepcopy(self, memo)
107+
finally:
108+
self.__deepcopy__ = deepcopy_method
109+
cp.__deepcopy__ = deepcopy_method
110+
cp._callbacks_registry.clear()
111+
cp._setup()
112+
return cp
113+
101114
def _get_initial_state(self):
102115
current_state_value = self.start_value if self.start_value else self.initial_state.value
103116
try:
104117
return self.states_map[current_state_value]
105118
except KeyError as err:
106119
raise InvalidStateValue(current_state_value) from err
107120

108-
def _activate_initial_state(self, initial_transition):
121+
def _activate_initial_state(self):
109122
if self.current_state_value is None:
110123
# send an one-time event `__initial__` to enter the current state.
111124
# current_state = self.current_state
112-
initial_transition.before.clear()
113-
initial_transition.on.clear()
114-
initial_transition.after.clear()
125+
self._initial_transition.before.clear()
126+
self._initial_transition.on.clear()
127+
self._initial_transition.after.clear()
115128

116129
event_data = EventData(
117130
trigger_data=TriggerData(
118131
machine=self,
119-
event=initial_transition.event,
132+
event=self._initial_transition.event,
120133
),
121-
transition=initial_transition,
134+
transition=self._initial_transition,
122135
)
123136
self._activate(event_data)
124137

@@ -142,12 +155,7 @@ def _visit_states_and_transitions(self, visitor):
142155
for transition in state.transitions:
143156
visitor(transition)
144157

145-
def _setup(self, initial_transition: Transition):
146-
"""
147-
Args:
148-
initial_transition: A special :ref:`transition` that triggers the enter on the
149-
`initial` :ref:`State`.
150-
"""
158+
def _setup(self):
151159
machine = ObjectConfig.from_obj(self, skip_attrs=self._get_protected_attrs())
152160
model = ObjectConfig.from_obj(self.model, skip_attrs={self.state_field})
153161
default_resolver = resolver_factory(machine, model)
@@ -162,7 +170,7 @@ def setup_visitor(visited):
162170

163171
self._visit_states_and_transitions(setup_visitor)
164172

165-
initial_transition._setup(register)
173+
self._initial_transition._setup(register)
166174

167175
def _build_observers_visitor(self, *observers):
168176
registry_callbacks = [

tests/test_deepcopy.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from copy import deepcopy
2+
3+
import pytest
4+
5+
from statemachine import State
6+
from statemachine import StateMachine
7+
from statemachine.exceptions import TransitionNotAllowed
8+
9+
10+
def test_deepcopy():
11+
class MySM(StateMachine):
12+
draft = State("Draft", initial=True, value="draft")
13+
published = State("Published", value="published")
14+
15+
publish = draft.to(published, cond="let_me_be_visible")
16+
17+
class MyModel:
18+
let_me_be_visible = False
19+
20+
sm = MySM(MyModel())
21+
22+
sm2 = deepcopy(sm)
23+
24+
with pytest.raises(TransitionNotAllowed):
25+
sm2.send("publish")

0 commit comments

Comments
 (0)