1111from optree .pytree import PyTreeSpec
1212from pybaum import get_registry as get_pybaum_registry
1313
14+ from optimagic .typing import extended_namespace
15+
1416
1517def get_registry (extended = False , data_col = "value" ):
1618 """Return pytree registry.
@@ -94,14 +96,11 @@ def _index_element_to_string(element):
9496 return res_string
9597
9698
97- extended = "extended"
98-
99-
10099def tree_flatten (tree , is_leaf = None , registry = None ):
101100 if isinstance (tree , dict ):
102101 tree = OrderedDict (tree )
103102 return optree .tree_flatten (
104- tree , is_leaf = is_leaf , namespace = extended if registry else ""
103+ tree , is_leaf = is_leaf , namespace = extended_namespace if registry else ""
105104 )
106105
107106
@@ -118,7 +117,7 @@ def tree_unflatten(treedef, leaves, is_leaf=None, registry=None):
118117
119118def tree_map (func , tree , is_leaf = None , registry = None ):
120119 return optree .tree_map (
121- func , tree , is_leaf = is_leaf , namespace = extended if registry else ""
120+ func , tree , is_leaf = is_leaf , namespace = extended_namespace if registry else ""
122121 )
123122
124123
@@ -174,7 +173,7 @@ def _unflatten_df_optree(aux_data, leaves):
174173 pd .DataFrame ,
175174 _flatten_df_optree ,
176175 _unflatten_df_optree ,
177- namespace = extended ,
176+ namespace = extended_namespace ,
178177)
179178
180179optree .register_pytree_node (
@@ -185,12 +184,12 @@ def _unflatten_df_optree(aux_data, leaves):
185184 list (sr .index .map (_index_element_to_string )),
186185 ),
187186 lambda aux_data , leaves : pd .Series (leaves , ** aux_data ),
188- namespace = extended ,
187+ namespace = extended_namespace ,
189188)
190189
191190optree .register_pytree_node (
192191 np .ndarray ,
193192 lambda arr : (arr .flatten ().tolist (), arr .shape , _array_element_names (arr )),
194193 lambda aux_data , leaves : np .array (leaves ).reshape (aux_data ),
195- namespace = extended ,
194+ namespace = extended_namespace ,
196195)
0 commit comments