33import warnings
44from functools import partial
55from itertools import product
6+ from typing import Any , Callable , Iterable
67
78import numpy as np
89import optree
910import pandas as pd
1011from optree .pytree import PyTreeSpec
1112
12- from optimagic .typing import DEFAULT_NAMESPACE , OPTREE_NAMESPACES
13+ from optimagic .typing import DEFAULT_NAMESPACE , OPTREE_NAMESPACES , PyTree
1314
1415try :
1516 import jax .numpy as jnp # type: ignore[import-not-found]
2223_are_namespaces_registered = False
2324
2425
25- def tree_flatten (tree , is_leaf = None , namespace = DEFAULT_NAMESPACE ):
26+ def tree_flatten (
27+ tree : PyTree ,
28+ is_leaf : Callable [[PyTree ], bool ] | None = None ,
29+ namespace : str = DEFAULT_NAMESPACE ,
30+ ) -> tuple [list , PyTreeSpec ]:
2631 """Flatten a pytree."""
2732 _register_namespaces ()
2833 _check_namespace (namespace )
2934 with optree .dict_insertion_ordered (True , namespace = namespace ):
3035 return optree .tree_flatten (tree , is_leaf = is_leaf , namespace = namespace )
3136
3237
33- def tree_leaves (tree , is_leaf = None , namespace = DEFAULT_NAMESPACE ):
38+ def tree_leaves (
39+ tree : PyTree ,
40+ is_leaf : Callable [[PyTree ], bool ] | None = None ,
41+ namespace : str = DEFAULT_NAMESPACE ,
42+ ) -> list :
3443 """Get the leaves of a pytree."""
3544 _register_namespaces ()
3645 _check_namespace (namespace )
3746 with optree .dict_insertion_ordered (True , namespace = namespace ):
3847 return optree .tree_leaves (tree , is_leaf = is_leaf , namespace = namespace )
3948
4049
41- def tree_unflatten (treedef , leaves , namespace = DEFAULT_NAMESPACE ):
50+ def tree_unflatten (
51+ treedef : PyTree | PyTreeSpec ,
52+ leaves : Iterable ,
53+ namespace : str = DEFAULT_NAMESPACE ,
54+ ) -> PyTree :
4255 """Reconstruct a pytree from the tree definition and the leaves."""
4356 _register_namespaces ()
4457
@@ -52,7 +65,12 @@ def tree_unflatten(treedef, leaves, namespace=DEFAULT_NAMESPACE):
5265 return optree .tree_unflatten (treedef , leaves )
5366
5467
55- def tree_map (func , tree , is_leaf = None , namespace = DEFAULT_NAMESPACE ):
68+ def tree_map (
69+ func : Callable [[PyTree ], PyTree ],
70+ tree : PyTree ,
71+ is_leaf : Callable [[PyTree ], bool ] | None = None ,
72+ namespace : str = DEFAULT_NAMESPACE ,
73+ ) -> PyTree :
5674 """Map an input function over pytree args to produce a new pytree."""
5775 _register_namespaces ()
5876 _check_namespace (namespace )
@@ -62,7 +80,12 @@ def tree_map(func, tree, is_leaf=None, namespace=DEFAULT_NAMESPACE):
6280 return optree .tree_map (func , tree , is_leaf = is_leaf , namespace = namespace )
6381
6482
65- def leaf_names (tree , is_leaf = None , namespace = DEFAULT_NAMESPACE , separator = "_" ):
83+ def leaf_names (
84+ tree : PyTree ,
85+ is_leaf : Callable [[PyTree ], bool ] | None = None ,
86+ namespace : str = DEFAULT_NAMESPACE ,
87+ separator : str = "_" ,
88+ ) -> list [str ]:
6689 """Get the path names for tree leaves."""
6790 _register_namespaces ()
6891 _check_namespace (namespace )
@@ -75,8 +98,12 @@ def leaf_names(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE, separator="_"):
7598
7699
77100def tree_equal (
78- tree , other , is_leaf = None , namespace = DEFAULT_NAMESPACE , equality_checkers = None
79- ):
101+ tree : PyTree ,
102+ other : PyTree ,
103+ is_leaf : Callable [[PyTree ], bool ] | None = None ,
104+ namespace : str = DEFAULT_NAMESPACE ,
105+ equality_checkers : dict [str , Callable [[Any , Any ], bool ]] | None = None ,
106+ ) -> bool :
80107 """Check the equality between two trees."""
81108 equality_checkers = (
82109 _get_equality_checkers ()
@@ -238,7 +265,8 @@ def _unflatten_jax_array(aux_data, leaves):
238265 return jnp .array (leaves ).reshape (aux_data )
239266
240267
241- def _get_df_names (df ):
268+ def _get_df_names (df : pd .DataFrame ) -> list [str ]:
269+ """Get string names for dataframe leaf paths."""
242270 index_strings = list (df .index .map (_index_element_to_string ))
243271 if "value" in df :
244272 out = index_strings
@@ -248,7 +276,8 @@ def _get_df_names(df):
248276 return out
249277
250278
251- def _index_element_to_string (element ):
279+ def _index_element_to_string (element : Any ) -> str :
280+ """Convert an index element to its string representation."""
252281 if isinstance (element , (tuple , list )):
253282 as_strings = [str (entry ) for entry in element ]
254283 res_string = "_" .join (as_strings )
@@ -258,7 +287,8 @@ def _index_element_to_string(element):
258287 return res_string
259288
260289
261- def _array_element_names (arr ):
290+ def _array_element_names (arr : np .ndarray ) -> list [str ]:
291+ """Get string names for array like element leaf paths."""
262292 dim_names = [map (str , range (n )) for n in arr .shape ]
263293 names = list (map ("_" .join , product (* dim_names )))
264294 return names
0 commit comments