Skip to content

Commit b1371b2

Browse files
committed
fix: add default namespace for dict insertion ordering
1 parent 8257c09 commit b1371b2

3 files changed

Lines changed: 48 additions & 62 deletions

File tree

src/optimagic/parameters/tree_registry.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pandas as pd
1010
from optree.pytree import PyTreeSpec
1111

12-
from optimagic.typing import OPTREE_NAMESPACES
12+
from optimagic.typing import DEFAULT_NAMESPACE, OPTREE_NAMESPACES
1313

1414
try:
1515
import jax.numpy as jnp # type: ignore[import-not-found]
@@ -21,72 +21,61 @@
2121
_are_namespaces_registered = False
2222

2323

24-
def tree_flatten(tree, is_leaf=None, namespace=""):
24+
def tree_flatten(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE):
2525
"""Flatten a pytree."""
2626
_register_namespaces()
2727
_check_namespace(namespace)
28+
with optree.dict_insertion_ordered(True, namespace=namespace):
29+
return optree.tree_flatten(tree, is_leaf=is_leaf, namespace=namespace)
2830

29-
return _with_insertion_order(namespace, optree.tree_flatten, tree, is_leaf=is_leaf)
3031

31-
32-
def tree_just_flatten(tree, is_leaf=None, namespace=""):
32+
def tree_just_flatten(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE):
3333
"""Get the leaves of a pytree."""
3434
_register_namespaces()
3535
_check_namespace(namespace)
36-
37-
return _with_insertion_order(namespace, optree.tree_leaves, tree, is_leaf=is_leaf)
36+
with optree.dict_insertion_ordered(True, namespace=namespace):
37+
return optree.tree_leaves(tree, is_leaf=is_leaf, namespace=namespace)
3838

3939

40-
def tree_unflatten(treedef, leaves, namespace=""):
40+
def tree_unflatten(treedef, leaves, namespace=DEFAULT_NAMESPACE):
4141
"""Reconstruct a pytree from the tree definition and the leaves."""
4242
_register_namespaces()
4343

4444
if not isinstance(treedef, PyTreeSpec):
4545
_check_namespace(namespace)
46-
treedef = _with_insertion_order(namespace, optree.tree_structure, treedef)
46+
with optree.dict_insertion_ordered(True, namespace=namespace):
47+
treedef = optree.tree_structure(treedef, namespace=namespace)
4748

48-
# optree.tree_unflatten doesn't need to be wrapped with _with_insertion_order
49-
# because it keeps the insertion order for dictionaries.
49+
# Doesn't need to be wrapped with dict_insertion_ordered
50+
# because it keeps the insertion order for dictionaries by default.
5051
return optree.tree_unflatten(treedef, leaves)
5152

5253

53-
def tree_map(func, tree, is_leaf=None, namespace=""):
54-
"""Map an input function over pytree args to produce a new pytree.
55-
56-
optree.tree_map always respects insertion order for dictionaries and doesn't
57-
require to be wrapped with _with_insertion_order.
58-
"""
54+
def tree_map(func, tree, is_leaf=None, namespace=DEFAULT_NAMESPACE):
55+
"""Map an input function over pytree args to produce a new pytree."""
5956
_register_namespaces()
6057
_check_namespace(namespace)
6158

59+
# Doesn't need to be wrapped with dict_insertion_ordered
60+
# because it keeps the insertion order for dictionaries by default.
6261
return optree.tree_map(func, tree, is_leaf=is_leaf, namespace=namespace)
6362

6463

65-
def leaf_names(tree, is_leaf=None, namespace="", separator="_"):
64+
def leaf_names(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE, separator="_"):
6665
"""Get the path names for tree leaves."""
6766
_register_namespaces()
6867
_check_namespace(namespace)
6968

70-
paths, _, _ = _with_insertion_order(
71-
namespace, optree.tree_flatten_with_path, tree, is_leaf=is_leaf
72-
)
69+
with optree.dict_insertion_ordered(True, namespace=namespace):
70+
paths, _, _ = optree.tree_flatten_with_path(
71+
tree, is_leaf=is_leaf, namespace=namespace
72+
)
7373
return [separator.join(str(p) for p in path) for path in paths]
7474

7575

76-
def _with_insertion_order(namespace, optree_func, *args, **kwargs):
77-
"""Call an optree function, preserving dict key order within a namespace.
78-
79-
By default, optree sorts dictionary keys. When a namespace is provided,
80-
this wrapper enables dict_insertion_ordered mode so that the original
81-
key order is preserved in the output.
82-
"""
83-
if namespace:
84-
with optree.dict_insertion_ordered(True, namespace=namespace):
85-
return optree_func(*args, namespace=namespace, **kwargs)
86-
return optree_func(*args, namespace=namespace, **kwargs)
87-
88-
89-
def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None):
76+
def tree_equal(
77+
tree, other, is_leaf=None, namespace=DEFAULT_NAMESPACE, equality_checkers=None
78+
):
9079
"""Check the equality between two trees."""
9180
equality_checkers = (
9281
_get_equality_checkers()
@@ -136,8 +125,8 @@ def _get_equality_checkers():
136125

137126

138127
def _check_namespace(namespace: str) -> None:
139-
"""Checks if the namespace is a registered and raise a warning."""
140-
if namespace and namespace not in OPTREE_NAMESPACES:
128+
"""Checks if the namespace is registered and raise a warning."""
129+
if namespace != DEFAULT_NAMESPACE and namespace not in OPTREE_NAMESPACES:
141130
warnings.warn(
142131
f"Namespace '{namespace}' is not registered. "
143132
f"Registered namespaces are: {','.join(OPTREE_NAMESPACES)}. "

src/optimagic/typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,12 @@ class MultiStartIterationHistory(TupleLikeAccess):
174174
exploration: IterationHistory | None = None
175175

176176

177+
DEFAULT_NAMESPACE = "optimagic_namespace"
177178
OPTREE_NAMESPACES = (
178179
"value",
179180
"lower_bound",
180181
"upper_bound",
181182
"soft_lower_bound",
182183
"soft_upper_bound",
183184
)
184-
185185
VALUE_NAMESPACE = OPTREE_NAMESPACES[0]

tests/optimagic/parameters/test_tree_registry.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -158,41 +158,38 @@ def test_tree_flatten_and_unflatten_with_None():
158158
assert tree == [None]
159159

160160

161-
def test_dict_insertion_ordering_is_respected_for_registered_namespaces():
161+
@pytest.mark.parametrize("namespace", OPTREE_NAMESPACES)
162+
def test_dict_insertion_ordering_is_respected_for_registered_namespaces(namespace):
162163
params = {"b": [1, 4], "a": [8, 9]}
163-
leaves, _ = tree_flatten(params, namespace=VALUE_NAMESPACE)
164+
leaves, _ = tree_flatten(params, namespace=namespace)
164165
assert leaves == [1, 4, 8, 9]
165-
leaves2 = tree_just_flatten(params, namespace=VALUE_NAMESPACE)
166-
assert leaves2 == [1, 4, 8, 9]
167-
names = leaf_names(params, namespace=VALUE_NAMESPACE)
168-
assert names == ["b_0", "b_1", "a_0", "a_1"]
169166

167+
tree = tree_unflatten(params, [1, 4, 8, 9], namespace=namespace)
168+
assert list(tree.items()) == [("b", [1, 4]), ("a", [8, 9])]
170169

171-
def test_dict_ordering_default_behaviour_is_by_name():
172-
params = {"b": [1, 4], "a": [8, 9]}
173-
leaves, _ = tree_flatten(params)
174-
assert leaves == [8, 9, 1, 4]
170+
leaves2 = tree_just_flatten(params, namespace=namespace)
171+
assert leaves2 == [1, 4, 8, 9]
175172

176-
leaves2 = tree_just_flatten(params)
177-
assert leaves2 == [8, 9, 1, 4]
173+
tree = tree_map(lambda x: x, params, namespace=namespace)
174+
assert list(tree.items()) == [("b", [1, 4]), ("a", [8, 9])]
178175

179-
names = leaf_names(params)
180-
assert names == ["a_0", "a_1", "b_0", "b_1"]
176+
names = leaf_names(params, namespace=namespace)
177+
assert names == ["b_0", "b_1", "a_0", "a_1"]
181178

182179

183-
def test_unflatten_respects_insertion_order():
180+
def test_dict_insertion_ordering_is_respected_for_default_namespace():
184181
params = {"b": [1, 4], "a": [8, 9]}
185-
leaves, treespec = tree_flatten(params)
186-
tree = tree_unflatten(treespec, leaves)
182+
leaves, _ = tree_flatten(params)
183+
assert leaves == [1, 4, 8, 9]
184+
185+
tree = tree_unflatten(params, [1, 4, 8, 9])
187186
assert list(tree.items()) == [("b", [1, 4]), ("a", [8, 9])]
188-
leaves2, treespec2 = tree_flatten(params, namespace=VALUE_NAMESPACE)
189-
tree2 = tree_unflatten(treespec2, leaves2)
190-
assert list(tree2.items()) == [("b", [1, 4]), ("a", [8, 9])]
191187

188+
leaves2 = tree_just_flatten(params)
189+
assert leaves2 == [1, 4, 8, 9]
192190

193-
def test_map_always_respects_insertion_order():
194-
params = {"b": [1, 4], "a": [8, 9]}
195191
tree = tree_map(lambda x: x, params)
196192
assert list(tree.items()) == [("b", [1, 4]), ("a", [8, 9])]
197-
tree2 = tree_map(lambda x: x, params, namespace=VALUE_NAMESPACE)
198-
assert list(tree2.items()) == [("b", [1, 4]), ("a", [8, 9])]
193+
194+
names = leaf_names(params)
195+
assert names == ["b_0", "b_1", "a_0", "a_1"]

0 commit comments

Comments
 (0)