Skip to content

Commit 81fd1ec

Browse files
committed
chore: raise warning for unregistered namespaces
1 parent c85fe08 commit 81fd1ec

2 files changed

Lines changed: 30 additions & 3 deletions

File tree

src/optimagic/parameters/tree_registry.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Wrapper around optree to tailor it to optimagic."""
22

3+
import warnings
34
from functools import partial
45
from itertools import product
56

@@ -23,13 +24,15 @@
2324
def tree_flatten(tree, is_leaf=None, namespace=""):
2425
"""Flatten a pytree."""
2526
_register_namespaces()
27+
_check_namespace(namespace)
2628

2729
return _with_insertion_order(namespace, optree.tree_flatten, tree, is_leaf=is_leaf)
2830

2931

3032
def tree_just_flatten(tree, is_leaf=None, namespace=""):
3133
"""Get the leaves of a pytree."""
3234
_register_namespaces()
35+
_check_namespace(namespace)
3336

3437
return _with_insertion_order(namespace, optree.tree_leaves, tree, is_leaf=is_leaf)
3538

@@ -39,6 +42,7 @@ def tree_unflatten(treedef, leaves, namespace=""):
3942
_register_namespaces()
4043

4144
if not isinstance(treedef, PyTreeSpec):
45+
_check_namespace(namespace)
4246
treedef = _with_insertion_order(namespace, optree.tree_structure, treedef)
4347

4448
# optree.tree_unflatten doesn't need to be wrapped with _with_insertion_order
@@ -53,12 +57,15 @@ def tree_map(func, tree, is_leaf=None, namespace=""):
5357
require to be wrapped with _with_insertion_order.
5458
"""
5559
_register_namespaces()
60+
_check_namespace(namespace)
61+
5662
return optree.tree_map(func, tree, is_leaf=is_leaf, namespace=namespace)
5763

5864

5965
def leaf_names(tree, is_leaf=None, namespace="", separator="_"):
6066
"""Get the path names for tree leaves."""
6167
_register_namespaces()
68+
_check_namespace(namespace)
6269

6370
paths, _, _ = _with_insertion_order(
6471
namespace, optree.tree_flatten_with_path, tree, is_leaf=is_leaf
@@ -128,6 +135,16 @@ def _get_equality_checkers():
128135
return equality_checkers
129136

130137

138+
def _check_namespace(namespace: str) -> None:
139+
"""Checks if the namespace is a registered and raise a warning."""
140+
if namespace and namespace not in OPTREE_NAMESPACES:
141+
warnings.warn(
142+
f"Namespace '{namespace}' is not registered. "
143+
f"Registered namespaces are: {','.join(OPTREE_NAMESPACES)}. "
144+
"Pytree method is being parsed with the default optree namespace."
145+
)
146+
147+
131148
def _register_namespaces() -> None:
132149
"""Register pytree flatten/unflatten methods for each namespace.
133150

tests/optimagic/parameters/test_tree_registry.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,20 @@ def test_tree_methods_with_optimagic_namespace(namespace, bounds_df):
132132
assert_frame_equal(tree, expected)
133133

134134

135-
def test_tree_flatten_with_unregisted_namespace(value_df):
135+
@pytest.mark.parametrize(
136+
"tree_method",
137+
[tree_flatten, tree_just_flatten, leaf_names, tree_map, tree_flatten],
138+
)
139+
def test_tree_methods_raise_warning_with_unregisted_namespace(tree_method, value_df):
136140
"""If namespace is not registered optree method fallbacks to default behaviour."""
137-
leaves, _ = tree_flatten(value_df, namespace="unregistered_namespace")
138-
assert leaves == [value_df]
141+
unregistered_namespace = "unregistered_namespace"
142+
with pytest.warns(match="is not registered."):
143+
if tree_method == tree_map:
144+
_ = tree_map(lambda x: x, value_df, namespace=unregistered_namespace)
145+
elif tree_method == tree_unflatten:
146+
_ = tree_method(value_df, [], namespace=unregistered_namespace)
147+
else:
148+
_ = tree_method(value_df, namespace=unregistered_namespace)
139149

140150

141151
def test_tree_flatten_and_unflatten_with_None():

0 commit comments

Comments
 (0)