11"""Wrapper around optree to tailor it to optimagic."""
22
3+ import warnings
34from functools import partial
45from itertools import product
56
2324def tree_flatten (tree , is_leaf = None , namespace = "" ):
2425 """Flatten a pytree."""
2526 _register_namespaces ()
27+ _check_namespace (namespace )
2628
2729 return _with_insertion_order (namespace , optree .tree_flatten , tree , is_leaf = is_leaf )
2830
2931
3032def tree_just_flatten (tree , is_leaf = None , namespace = "" ):
3133 """Get the leaves of a pytree."""
3234 _register_namespaces ()
35+ _check_namespace (namespace )
3336
3437 return _with_insertion_order (namespace , optree .tree_leaves , tree , is_leaf = is_leaf )
3538
@@ -39,6 +42,7 @@ def tree_unflatten(treedef, leaves, namespace=""):
3942 _register_namespaces ()
4043
4144 if not isinstance (treedef , PyTreeSpec ):
45+ _check_namespace (namespace )
4246 treedef = _with_insertion_order (namespace , optree .tree_structure , treedef )
4347
4448 # optree.tree_unflatten doesn't need to be wrapped with _with_insertion_order
@@ -53,12 +57,15 @@ def tree_map(func, tree, is_leaf=None, namespace=""):
5357 require to be wrapped with _with_insertion_order.
5458 """
5559 _register_namespaces ()
60+ _check_namespace (namespace )
61+
5662 return optree .tree_map (func , tree , is_leaf = is_leaf , namespace = namespace )
5763
5864
5965def leaf_names (tree , is_leaf = None , namespace = "" , separator = "_" ):
6066 """Get the path names for tree leaves."""
6167 _register_namespaces ()
68+ _check_namespace (namespace )
6269
6370 paths , _ , _ = _with_insertion_order (
6471 namespace , optree .tree_flatten_with_path , tree , is_leaf = is_leaf
@@ -128,6 +135,16 @@ def _get_equality_checkers():
128135 return equality_checkers
129136
130137
138+ def _check_namespace (namespace : str ) -> None :
139+ """Checks if the namespace is a registered and raise a warning."""
140+ if namespace and namespace not in OPTREE_NAMESPACES :
141+ warnings .warn (
142+ f"Namespace '{ namespace } ' is not registered. "
143+ f"Registered namespaces are: { ',' .join (OPTREE_NAMESPACES )} . "
144+ "Pytree method is being parsed with the default optree namespace."
145+ )
146+
147+
131148def _register_namespaces () -> None :
132149 """Register pytree flatten/unflatten methods for each namespace.
133150
0 commit comments