Skip to content

Commit 1f07a02

Browse files
committed
chore: move namespace variable to typing.py
1 parent 2edc2f2 commit 1f07a02

2 files changed

Lines changed: 8 additions & 8 deletions

File tree

src/optimagic/parameters/tree_registry.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from optree.pytree import PyTreeSpec
1212
from pybaum import get_registry as get_pybaum_registry
1313

14+
from optimagic.typing import extended_namespace
15+
1416

1517
def get_registry(extended=False, data_col="value"):
1618
"""Return pytree registry.
@@ -94,14 +96,11 @@ def _index_element_to_string(element):
9496
return res_string
9597

9698

97-
extended = "extended"
98-
99-
10099
def tree_flatten(tree, is_leaf=None, registry=None):
101100
if isinstance(tree, dict):
102101
tree = OrderedDict(tree)
103102
return optree.tree_flatten(
104-
tree, is_leaf=is_leaf, namespace=extended if registry else ""
103+
tree, is_leaf=is_leaf, namespace=extended_namespace if registry else ""
105104
)
106105

107106

@@ -118,7 +117,7 @@ def tree_unflatten(treedef, leaves, is_leaf=None, registry=None):
118117

119118
def tree_map(func, tree, is_leaf=None, registry=None):
120119
return optree.tree_map(
121-
func, tree, is_leaf=is_leaf, namespace=extended if registry else ""
120+
func, tree, is_leaf=is_leaf, namespace=extended_namespace if registry else ""
122121
)
123122

124123

@@ -174,7 +173,7 @@ def _unflatten_df_optree(aux_data, leaves):
174173
pd.DataFrame,
175174
_flatten_df_optree,
176175
_unflatten_df_optree,
177-
namespace=extended,
176+
namespace=extended_namespace,
178177
)
179178

180179
optree.register_pytree_node(
@@ -185,12 +184,12 @@ def _unflatten_df_optree(aux_data, leaves):
185184
list(sr.index.map(_index_element_to_string)),
186185
),
187186
lambda aux_data, leaves: pd.Series(leaves, **aux_data),
188-
namespace=extended,
187+
namespace=extended_namespace,
189188
)
190189

191190
optree.register_pytree_node(
192191
np.ndarray,
193192
lambda arr: (arr.flatten().tolist(), arr.shape, _array_element_names(arr)),
194193
lambda aux_data, leaves: np.array(leaves).reshape(aux_data),
195-
namespace=extended,
194+
namespace=extended_namespace,
196195
)

src/optimagic/typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Scalar = Any
2323

2424
T = TypeVar("T")
25+
extended_namespace = "extended_namespace"
2526

2627

2728
class AggregationLevel(Enum):

0 commit comments

Comments
 (0)