Skip to content

Commit 109623f

Browse files
committed
chore: use optree context manager for ordering dict
1 parent def8aa3 commit 109623f

1 file changed

Lines changed: 5 additions & 7 deletions

File tree

src/optimagic/parameters/tree_registry.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Wrapper around pybaum get_registry to tailor it to optimagic."""
22

33
import itertools
4-
from collections import OrderedDict
54
from functools import partial
65
from itertools import product
76

@@ -97,11 +96,10 @@ def _index_element_to_string(element):
9796

9897

9998
def tree_flatten(tree, is_leaf=None, registry=None):
100-
if isinstance(tree, dict):
101-
tree = OrderedDict(tree)
102-
return optree.tree_flatten(
103-
tree, is_leaf=is_leaf, namespace=extended_namespace if registry else ""
104-
)
99+
with optree.dict_insertion_ordered(True, namespace=extended_namespace):
100+
return optree.tree_flatten(
101+
tree, is_leaf=is_leaf, namespace=extended_namespace if registry else ""
102+
)
105103

106104

107105
def tree_just_flatten(tree, is_leaf=None, registry=None):
@@ -122,7 +120,7 @@ def tree_map(func, tree, is_leaf=None, registry=None):
122120

123121

124122
def leaf_names(tree, is_leaf=None, registry=None, separator="_"):
125-
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf, registry=registry)
123+
_, treespec = tree_flatten(tree, is_leaf=is_leaf, registry=registry)
126124
paths = treespec.paths()
127125
return [separator.join(str(p) for p in path) for path in paths]
128126

0 commit comments

Comments
 (0)