11"""Wrapper around pybaum get_registry to tailor it to optimagic."""
22
33import itertools
4- from collections import OrderedDict
54from functools import partial
65from itertools import product
76
@@ -97,11 +96,10 @@ def _index_element_to_string(element):
9796
9897
9998def tree_flatten (tree , is_leaf = None , registry = None ):
100- if isinstance (tree , dict ):
101- tree = OrderedDict (tree )
102- return optree .tree_flatten (
103- tree , is_leaf = is_leaf , namespace = extended_namespace if registry else ""
104- )
99+ with optree .dict_insertion_ordered (True , namespace = extended_namespace ):
100+ return optree .tree_flatten (
101+ tree , is_leaf = is_leaf , namespace = extended_namespace if registry else ""
102+ )
105103
106104
107105def tree_just_flatten (tree , is_leaf = None , registry = None ):
@@ -122,7 +120,7 @@ def tree_map(func, tree, is_leaf=None, registry=None):
122120
123121
124122def leaf_names (tree , is_leaf = None , registry = None , separator = "_" ):
125- leaves , treespec = tree_flatten (tree , is_leaf = is_leaf , registry = registry )
123+ _ , treespec = tree_flatten (tree , is_leaf = is_leaf , registry = registry )
126124 paths = treespec .paths ()
127125 return [separator .join (str (p ) for p in path ) for path in paths ]
128126
0 commit comments